WebashalarForML's picture
Upload 178 files
330b6e4 verified
#!/usr/bin/env python3
"""Database migration script for the chat agent application."""
import os
import sys
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
import argparse
from pathlib import Path
# Add the parent directory to the path so we can import config
sys.path.append(str(Path(__file__).parent.parent))
from config import config
class DatabaseMigrator:
    """Apply and report on plain-SQL schema migrations for the chat agent.

    Migrations are the ``*.sql`` files located next to this script; each
    file's stem is used as its version string.  Versions that have been
    applied are recorded in the ``schema_migrations`` table.
    """

    def __init__(self, database_url=None, config_name='development'):
        """Initialise the migrator.

        Args:
            database_url: PostgreSQL connection URL.  When falsy, the URL is
                read from ``config[config_name].SQLALCHEMY_DATABASE_URI``.
            config_name: Key into the application ``config`` mapping, used
                only when *database_url* is not given.
        """
        if database_url:
            self.database_url = database_url
        else:
            app_config = config[config_name]
            self.database_url = app_config.SQLALCHEMY_DATABASE_URI
        self.migrations_dir = Path(__file__).parent

    def get_connection(self, autocommit=True):
        """Open a new psycopg2 connection to ``self.database_url``.

        Args:
            autocommit: When true, each statement commits immediately; when
                false, the caller must ``commit()`` or ``rollback()``.

        Returns:
            An open connection.  The caller is responsible for closing it.
        """
        conn = psycopg2.connect(self.database_url)
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        return conn

    @staticmethod
    def _split_database_url(database_url):
        """Split *database_url* into ``(db_name, maintenance_url)``.

        ``maintenance_url`` is *database_url* with only its path replaced by
        ``/postgres``.  Scheme, credentials, host, port and query string are
        preserved.  A plain string replace of ``/<db_name>`` is unsafe
        because the same text can appear earlier in the URL, e.g. in
        ``postgresql://chat:pw@host/chat``.
        """
        from urllib.parse import unquote, urlparse

        parsed = urlparse(database_url)
        # URI components may be percent-encoded; decode to the real name.
        db_name = unquote(parsed.path.lstrip('/'))
        maintenance_url = parsed._replace(path='/postgres').geturl()
        return db_name, maintenance_url

    def create_database_if_not_exists(self):
        """Create the target database if the server does not have it yet.

        The target database may not exist, so this connects to the server's
        ``postgres`` database instead.  Autocommit is required because
        PostgreSQL does not allow ``CREATE DATABASE`` inside a transaction.

        Raises:
            psycopg2.Error: If connecting, checking or creating fails.
        """
        from psycopg2 import sql

        db_name, postgres_url = self._split_database_url(self.database_url)
        if not db_name:
            print("No database name in URL; skipping database creation")
            return

        try:
            conn = psycopg2.connect(postgres_url)
            try:
                conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
                with conn.cursor() as cursor:
                    cursor.execute(
                        "SELECT 1 FROM pg_database WHERE datname = %s",
                        (db_name,),
                    )
                    if cursor.fetchone() is None:
                        print(f"Creating database: {db_name}")
                        # Identifier quoting handles names containing quotes
                        # or other special characters safely.
                        cursor.execute(
                            sql.SQL('CREATE DATABASE {}').format(
                                sql.Identifier(db_name)
                            )
                        )
                        print(f"Database {db_name} created successfully")
                    else:
                        print(f"Database {db_name} already exists")
            finally:
                conn.close()
        except psycopg2.Error as e:
            print(f"Error creating database: {e}")
            raise

    def create_migrations_table(self):
        """Create the ``schema_migrations`` tracking table if it is missing."""
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS schema_migrations (
                        version VARCHAR(255) PRIMARY KEY,
                        applied_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP
                    )
                """)
        finally:
            conn.close()
        print("Migrations table created/verified")

    def get_applied_migrations(self):
        """Return the applied migration versions, sorted by version.

        Returns an empty list if querying ``schema_migrations`` raises a
        psycopg2 error.  Typically this means the table does not exist yet,
        e.g. when ``status`` runs before any ``migrate``.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                try:
                    cursor.execute(
                        "SELECT version FROM schema_migrations ORDER BY version"
                    )
                    return [version for (version,) in cursor.fetchall()]
                except psycopg2.Error:
                    # Table doesn't exist yet (or the query otherwise failed).
                    return []
        finally:
            conn.close()

    def get_available_migrations(self):
        """Return ``(version, path)`` for every ``*.sql`` file, sorted by path.

        The version is the file name without its ``.sql`` suffix.  Ordering
        is lexicographic, so version names should sort in the order they
        are meant to be applied (e.g. zero-padded numeric prefixes).
        """
        return [
            (path.stem, path)
            for path in sorted(self.migrations_dir.glob("*.sql"))
            if path.is_file()
        ]

    def apply_migration(self, version, file_path):
        """Run one migration file and record it, in a single transaction.

        Args:
            version: Version string to insert into ``schema_migrations``.
            file_path: Path to the ``.sql`` file to execute.

        Raises:
            OSError: If the migration file cannot be read.
            psycopg2.Error: If the SQL fails; the transaction is rolled back.
        """
        print(f"Applying migration: {version}")
        with open(file_path, 'r', encoding='utf-8') as f:
            migration_sql = f.read()

        conn = self.get_connection(autocommit=False)
        try:
            with conn.cursor() as cursor:
                cursor.execute(migration_sql)
                # The bookkeeping row shares the migration's transaction, so
                # it is committed only if the migration SQL succeeded.
                # NOTE(review): assumes migration files do not issue their own
                # COMMIT/BEGIN, which would break this atomicity.
                cursor.execute(
                    "INSERT INTO schema_migrations (version) VALUES (%s)",
                    (version,),
                )
            conn.commit()
            print(f"Migration {version} applied successfully")
        except psycopg2.Error as e:
            conn.rollback()
            print(f"Error applying migration {version}: {e}")
            raise
        finally:
            conn.close()

    def migrate(self, target_version=None):
        """Apply all pending migrations, in version order.

        Args:
            target_version: If given, only apply versions that compare
                ``<=`` this string (lexicographic comparison).
        """
        print("Starting database migration...")

        self.create_database_if_not_exists()
        self.create_migrations_table()

        applied = set(self.get_applied_migrations())
        to_apply = [
            (version, file_path)
            for version, file_path in self.get_available_migrations()
            if version not in applied
            and (target_version is None or version <= target_version)
        ]

        if not to_apply:
            print("No pending migrations to apply")
            return

        for version, file_path in to_apply:
            self.apply_migration(version, file_path)

        print(f"Migration completed. Applied {len(to_apply)} migrations.")

    def status(self):
        """Print each available migration as APPLIED or PENDING, plus totals.

        Any exception is reported on stdout rather than raised.
        """
        try:
            applied = set(self.get_applied_migrations())
            available = self.get_available_migrations()

            print("Migration Status:")
            print("-" * 50)
            for version, _file_path in available:
                status = "APPLIED" if version in applied else "PENDING"
                print(f"{version:<30} {status}")

            pending_count = sum(1 for v, _ in available if v not in applied)
            print(f"\nTotal migrations: {len(available)}")
            print(f"Applied: {len(applied)}")
            print(f"Pending: {pending_count}")
        except Exception as e:
            print(f"Error checking migration status: {e}")
def main():
    """Command-line entry point for the migration tool.

    Parses ``command`` (``migrate`` or ``status``) and the optional
    ``--config``, ``--database-url`` and ``--target`` flags, builds a
    ``DatabaseMigrator`` from them and runs the chosen command.
    """
    parser = argparse.ArgumentParser(description="Database migration tool")
    parser.add_argument("command", choices=["migrate", "status"],
                        help="Migration command to run")
    parser.add_argument("--config", default="development",
                        choices=["development", "production", "testing"],
                        help="Configuration environment")
    parser.add_argument("--database-url",
                        help="Database URL (overrides config)")
    parser.add_argument("--target", help="Target migration version")
    options = parser.parse_args()

    migrator = DatabaseMigrator(database_url=options.database_url,
                                config_name=options.config)

    # argparse's `choices` guarantees the command is one of these keys.
    handlers = {
        "migrate": lambda: migrator.migrate(target_version=options.target),
        "status": migrator.status,
    }
    handlers[options.command]()


if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()