Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """Database migration script for the chat agent application.""" | |
| import os | |
| import sys | |
| import psycopg2 | |
| from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT | |
| import argparse | |
| from pathlib import Path | |
| # Add the parent directory to the path so we can import config | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from config import config | |
class DatabaseMigrator:
    """Handles database migrations for the chat agent.

    Each ``*.sql`` file in ``migrations_dir`` is one migration whose version
    is the file stem (``001_init.sql`` -> ``"001_init"``). Applied versions
    are recorded in the ``schema_migrations`` table; pending files are
    applied in file-name order, each in its own transaction.
    """

    def __init__(self, database_url=None, config_name='development'):
        """Initialize the migrator.

        Args:
            database_url: Database URL to use. When falsy, the
                ``SQLALCHEMY_DATABASE_URI`` of ``config[config_name]`` is used.
            config_name: Key into the application ``config`` mapping; only
                consulted when ``database_url`` is not given.
        """
        if database_url:
            self.database_url = database_url
        else:
            app_config = config[config_name]
            self.database_url = app_config.SQLALCHEMY_DATABASE_URI
        # Migration files live in the same directory as this script.
        self.migrations_dir = Path(__file__).parent

    @staticmethod
    def _to_libpq_url(url):
        """Return *url* in a form psycopg2/libpq accepts.

        SQLAlchemy URLs may carry a driver suffix in the scheme
        (``postgresql+psycopg2://...``), which libpq does not understand; the
        suffix is dropped. Any other input, including keyword/value DSNs such
        as ``"dbname=chat user=app"``, is returned unchanged.
        """
        from urllib.parse import urlsplit

        # urlsplit only reports a scheme for well-formed "scheme:" prefixes,
        # so a '+' inside a keyword/value DSN value is never mistaken for one.
        if '+' in urlsplit(url).scheme:
            raw_scheme, sep, rest = url.partition(':')
            return raw_scheme.split('+', 1)[0] + sep + rest
        return url

    @staticmethod
    def _split_database_url(url):
        """Split a ``scheme://`` database URL into its database name and a
        URL pointing at the ``postgres`` maintenance database.

        Returns:
            ``(db_name, maintenance_url)``, where ``db_name`` is the
            percent-decoded name from the URL path. Returns ``(None, None)``
            when *url* is not a ``scheme://`` URL (e.g. a keyword/value DSN)
            or does not name a database.
        """
        from urllib.parse import unquote, urlsplit

        parsed = urlsplit(url)
        db_name = unquote(parsed.path[1:])  # drop the leading '/'
        if not parsed.scheme or '://' not in url or not db_name:
            return None, None

        # Swap only the path component. A blanket str.replace would also
        # rewrite the user name when it equals the database name
        # (postgresql://chat@host/chat -> postgresql://postgres@host/postgres).
        base = url.split('#', 1)[0].split('?', 1)[0]
        if base.endswith(parsed.path):
            # Slicing keeps host-less forms such as 'postgresql:///chat' intact.
            maintenance_url = (
                base[:len(base) - len(parsed.path)] + '/postgres' + url[len(base):]
            )
        else:
            maintenance_url = parsed._replace(path='/postgres').geturl()
        return db_name, maintenance_url

    def get_connection(self, autocommit=True):
        """Open a new connection to the configured database.

        Args:
            autocommit: When true, switch the connection to autocommit mode.

        Returns:
            A psycopg2 connection; the caller is responsible for closing it.
        """
        conn = psycopg2.connect(self._to_libpq_url(self.database_url))
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        return conn

    def create_database_if_not_exists(self):
        """Create the target database if it does not already exist.

        Connects to the ``postgres`` maintenance database on the same server,
        because the target database may not exist yet, and creates the
        database named in the URL path when it is missing. When no database
        name can be derived from the URL, prints a message and does nothing.

        Raises:
            psycopg2.Error: If connecting, the existence check, or the
                ``CREATE DATABASE`` statement fails.
        """
        from psycopg2 import sql

        db_name, postgres_url = self._split_database_url(
            self._to_libpq_url(self.database_url)
        )
        if db_name is None:
            print("Could not determine database name from URL; skipping database creation")
            return

        try:
            conn = psycopg2.connect(postgres_url)
            try:
                # CREATE DATABASE cannot run inside a transaction block.
                conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
                with conn.cursor() as cursor:
                    cursor.execute(
                        "SELECT 1 FROM pg_database WHERE datname = %s", (db_name,)
                    )
                    if cursor.fetchone() is None:
                        print(f"Creating database: {db_name}")
                        # Quote the name as an identifier instead of splicing it in raw.
                        cursor.execute(
                            sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name))
                        )
                        print(f"Database {db_name} created successfully")
                    else:
                        print(f"Database {db_name} already exists")
            finally:
                conn.close()
        except psycopg2.Error as e:
            print(f"Error creating database: {e}")
            raise

    def create_migrations_table(self):
        """Create the ``schema_migrations`` tracking table if it is missing."""
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS schema_migrations (
                        version VARCHAR(255) PRIMARY KEY,
                        applied_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP
                    )
                """)
        finally:
            conn.close()
        print("Migrations table created/verified")

    def get_applied_migrations(self):
        """Return the applied migration versions in ascending order.

        Returns an empty list when ``schema_migrations`` cannot be queried
        (e.g. it has not been created yet). Errors while connecting are not
        suppressed.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                try:
                    cursor.execute("SELECT version FROM schema_migrations ORDER BY version")
                    return [row[0] for row in cursor.fetchall()]
                except psycopg2.Error:
                    # Table doesn't exist yet.
                    return []
        finally:
            conn.close()

    def get_available_migrations(self):
        """Return ``(version, path)`` pairs for every ``*.sql`` file in
        ``migrations_dir``, sorted by path; ``version`` is the file stem."""
        return [(path.stem, path) for path in sorted(self.migrations_dir.glob("*.sql"))]

    def apply_migration(self, version, file_path):
        """Apply one migration file and record it in ``schema_migrations``.

        The migration SQL and the bookkeeping insert share one transaction,
        so a failure leaves neither applied.

        Raises:
            OSError: If the migration file cannot be read.
            psycopg2.Error: If executing or recording the migration fails;
                the transaction is rolled back before re-raising.
        """
        print(f"Applying migration: {version}")
        # Read first so an unreadable file never opens a connection.
        with open(file_path, 'r', encoding='utf-8') as f:
            migration_sql = f.read()

        conn = self.get_connection(autocommit=False)
        try:
            with conn.cursor() as cursor:
                # Executing an empty query is an error in psycopg2; a blank
                # file is still recorded so it is not retried on every run.
                if migration_sql.strip():
                    cursor.execute(migration_sql)
                cursor.execute(
                    "INSERT INTO schema_migrations (version) VALUES (%s)",
                    (version,),
                )
            conn.commit()
            print(f"Migration {version} applied successfully")
        except psycopg2.Error as e:
            conn.rollback()
            print(f"Error applying migration {version}: {e}")
            raise
        finally:
            conn.close()

    def migrate(self, target_version=None):
        """Apply every pending migration, in order.

        Ensures the database and the tracking table exist first.

        Args:
            target_version: When given, only versions that compare
                (as strings) ``<= target_version`` are applied.
        """
        print("Starting database migration...")
        self.create_database_if_not_exists()
        self.create_migrations_table()

        applied = set(self.get_applied_migrations())
        available = self.get_available_migrations()

        to_apply = []
        for version, file_path in available:
            if version in applied:
                continue
            if target_version is None or version <= target_version:
                to_apply.append((version, file_path))

        if not to_apply:
            print("No pending migrations to apply")
            return

        for version, file_path in to_apply:
            self.apply_migration(version, file_path)
        print(f"Migration completed. Applied {len(to_apply)} migrations.")

    def status(self):
        """Print each available migration with its APPLIED/PENDING state.

        Any error (including connection failures) is printed rather than
        raised.
        """
        try:
            applied = set(self.get_applied_migrations())
            available = self.get_available_migrations()
            print("Migration Status:")
            print("-" * 50)
            for version, _ in available:
                status = "APPLIED" if version in applied else "PENDING"
                print(f"{version:<30} {status}")
            pending_count = sum(1 for v, _ in available if v not in applied)
            print(f"\nTotal migrations: {len(available)}")
            # NOTE(review): counts every recorded version, including any whose
            # .sql file is gone, so Applied + Pending can differ from Total.
            print(f"Applied: {len(applied)}")
            print(f"Pending: {pending_count}")
        except Exception as e:
            print(f"Error checking migration status: {e}")
def main():
    """Command-line entry point: parse arguments and run the chosen command.

    ``migrate`` applies pending migrations (optionally up to ``--target``);
    ``status`` prints which migrations are applied or pending.
    """
    parser = argparse.ArgumentParser(description="Database migration tool")
    parser.add_argument("command", choices=["migrate", "status"],
                        help="Migration command to run")
    parser.add_argument("--config", default="development",
                        choices=["development", "production", "testing"],
                        help="Configuration environment")
    parser.add_argument("--database-url",
                        help="Database URL (overrides config)")
    parser.add_argument("--target",
                        help="Target migration version")
    options = parser.parse_args()

    migrator = DatabaseMigrator(database_url=options.database_url,
                                config_name=options.config)

    # argparse's `choices` guarantees the command is one of these keys.
    handlers = {
        "migrate": lambda: migrator.migrate(target_version=options.target),
        "status": migrator.status,
    }
    handlers[options.command]()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()