File size: 7,902 Bytes
330b6e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#!/usr/bin/env python3
"""Database migration script for the chat agent application."""

import os
import sys
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
import argparse
from pathlib import Path

# Add the parent directory to the path so we can import config
sys.path.append(str(Path(__file__).parent.parent))

from config import config


class DatabaseMigrator:
    """Apply versioned ``*.sql`` migration files to a PostgreSQL database.

    Migration files are the ``*.sql`` files in this script's directory; a
    file's name without the ``.sql`` suffix is its version. Applied versions
    are recorded in the ``schema_migrations`` table. Versions are ordered and
    compared as plain strings, so file names must sort in the intended apply
    order (e.g. zero-padded numeric prefixes).
    """

    # SQLSTATE "undefined_table": raised when schema_migrations is missing.
    _UNDEFINED_TABLE = '42P01'

    def __init__(self, database_url=None, config_name='development'):
        """Initialize the migrator.

        Args:
            database_url: Connection URL passed to ``psycopg2.connect``. When
                falsy, the URL is taken from
                ``config[config_name].SQLALCHEMY_DATABASE_URI`` instead.
            config_name: Key into the project's ``config`` mapping; only
                consulted when ``database_url`` is not given.
        """
        if database_url:
            self.database_url = database_url
        else:
            app_config = config[config_name]
            self.database_url = app_config.SQLALCHEMY_DATABASE_URI

        # Migration files live next to this script.
        self.migrations_dir = Path(__file__).parent

    def get_connection(self, autocommit=True):
        """Open and return a new connection to ``self.database_url``.

        Args:
            autocommit: When true, switch the connection to autocommit mode;
                pass False to control the transaction with commit/rollback.
        """
        conn = psycopg2.connect(self.database_url)
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        return conn

    @staticmethod
    def _split_database_url(database_url):
        """Return ``(db_name, maintenance_url)`` derived from ``database_url``.

        ``maintenance_url`` is the same URL with only the database component
        replaced by ``postgres``; scheme, credentials, host, port and query
        parameters are carried over verbatim. ``db_name`` is percent-decoded,
        matching how libpq interprets the path of a connection URI.
        """
        from urllib.parse import unquote, urlsplit

        parts = urlsplit(database_url)
        db_name = unquote(parts.path[1:])  # drop the leading slash
        # Build the URL by hand: urlunsplit would collapse an empty netloc
        # ("postgresql:///db", socket form) into "postgresql:/postgres".
        maintenance_url = f"{parts.scheme}://{parts.netloc}/postgres"
        if parts.query:
            maintenance_url += f"?{parts.query}"
        return db_name, maintenance_url

    @staticmethod
    def _quote_ident(name):
        """Quote ``name`` as a PostgreSQL identifier, doubling embedded quotes."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _pending_migrations(applied, available, target_version=None):
        """Return the entries of ``available`` whose version is not in ``applied``.

        Order of ``available`` is preserved. When ``target_version`` is given,
        only versions comparing ``<= target_version`` as strings are kept.
        """
        return [
            (version, file_path)
            for version, file_path in available
            if version not in applied
            and (target_version is None or version <= target_version)
        ]

    def create_database_if_not_exists(self):
        """Create the target database when no database of that name exists.

        Connects to the ``postgres`` database on the same server with the
        same credentials and looks the name up in ``pg_database``.

        Raises:
            ValueError: If the URL does not name a database.
            psycopg2.Error: If connecting, checking or creating fails.
        """
        db_name, postgres_url = self._split_database_url(self.database_url)
        if not db_name:
            raise ValueError("Database URL does not specify a database name")

        try:
            conn = psycopg2.connect(postgres_url)
            try:
                # CREATE DATABASE cannot run inside a transaction block.
                conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
                with conn.cursor() as cursor:
                    cursor.execute(
                        "SELECT 1 FROM pg_database WHERE datname = %s", (db_name,)
                    )
                    if cursor.fetchone() is None:
                        print(f"Creating database: {db_name}")
                        # Identifiers cannot be bound as query parameters,
                        # so the name is quoted/escaped explicitly.
                        cursor.execute(f"CREATE DATABASE {self._quote_ident(db_name)}")
                        print(f"Database {db_name} created successfully")
                    else:
                        print(f"Database {db_name} already exists")
            finally:
                conn.close()

        except psycopg2.Error as e:
            print(f"Error creating database: {e}")
            raise

    def create_migrations_table(self):
        """Create the ``schema_migrations`` tracking table if it is missing.

        Each row holds a migration version (primary key) and the timestamp at
        which it was recorded (defaults to CURRENT_TIMESTAMP).
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute("""

            CREATE TABLE IF NOT EXISTS schema_migrations (

                version VARCHAR(255) PRIMARY KEY,

                applied_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP

            )

        """)
        finally:
            conn.close()
        print("Migrations table created/verified")

    def get_applied_migrations(self):
        """Return the versions recorded in ``schema_migrations``, sorted.

        Returns an empty list when the tracking table does not exist yet.
        Any other database error is re-raised rather than being mistaken for
        "nothing applied".
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                try:
                    cursor.execute("SELECT version FROM schema_migrations ORDER BY version")
                except psycopg2.Error as e:
                    if e.pgcode != self._UNDEFINED_TABLE:
                        raise
                    # Table doesn't exist yet
                    return []
                return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_available_migrations(self):
        """Return ``(version, path)`` for every ``*.sql`` file, sorted by path.

        The version is the file name without its ``.sql`` suffix.
        """
        return [
            (file_path.stem, file_path)
            for file_path in sorted(self.migrations_dir.glob("*.sql"))
        ]

    def apply_migration(self, version, file_path):
        """Execute one migration file and record its version.

        The file's SQL and the ``schema_migrations`` insert run in a single
        transaction (assuming the file issues no COMMIT of its own). On a
        database error the transaction is rolled back and the error re-raised.
        """
        print(f"Applying migration: {version}")

        # Read before connecting so an unreadable file never holds a connection.
        with open(file_path, 'r', encoding='utf-8') as f:
            migration_sql = f.read()

        conn = self.get_connection(autocommit=False)
        try:
            with conn.cursor() as cursor:
                cursor.execute(migration_sql)

                # Record the migration as applied
                cursor.execute(
                    "INSERT INTO schema_migrations (version) VALUES (%s)",
                    (version,)
                )

            conn.commit()
            print(f"Migration {version} applied successfully")

        except psycopg2.Error as e:
            conn.rollback()
            print(f"Error applying migration {version}: {e}")
            raise
        finally:
            conn.close()

    def migrate(self, target_version=None):
        """Create the database/tracking table if needed, then apply pending migrations.

        Args:
            target_version: Optional upper bound; migrations whose version
                sorts after it (string comparison) are not applied.
        """
        print("Starting database migration...")

        self.create_database_if_not_exists()
        self.create_migrations_table()

        applied = set(self.get_applied_migrations())
        available = self.get_available_migrations()
        to_apply = self._pending_migrations(applied, available, target_version)

        if not to_apply:
            print("No pending migrations to apply")
            return

        # Applied one by one, in file order; the first failure propagates.
        for version, file_path in to_apply:
            self.apply_migration(version, file_path)

        print(f"Migration completed. Applied {len(to_apply)} migrations.")

    def status(self):
        """Print each available migration as APPLIED or PENDING, plus totals.

        Errors are printed rather than raised.
        """
        try:
            applied = set(self.get_applied_migrations())
            available = self.get_available_migrations()

            print("Migration Status:")
            print("-" * 50)

            for version, _ in available:
                state = "APPLIED" if version in applied else "PENDING"
                print(f"{version:<30} {state}")

            pending_count = len(self._pending_migrations(applied, available))
            print(f"\nTotal migrations: {len(available)}")
            # NOTE: counts every recorded version, including any whose .sql
            # file is no longer present.
            print(f"Applied: {len(applied)}")
            print(f"Pending: {pending_count}")

        except Exception as e:
            print(f"Error checking migration status: {e}")


def main():
    """Command-line entry point: parse arguments and run one migration command.

    ``migrate`` applies pending migrations (optionally bounded by
    ``--target``); ``status`` prints which migrations are applied or pending.
    """
    cli = argparse.ArgumentParser(description="Database migration tool")
    cli.add_argument(
        "command",
        choices=["migrate", "status"],
        help="Migration command to run",
    )
    cli.add_argument(
        "--config",
        default="development",
        choices=["development", "production", "testing"],
        help="Configuration environment",
    )
    cli.add_argument("--database-url", help="Database URL (overrides config)")
    cli.add_argument("--target", help="Target migration version")

    opts = cli.parse_args()

    migrator = DatabaseMigrator(
        database_url=opts.database_url,
        config_name=opts.config,
    )

    # argparse's ``choices`` guarantees opts.command is one of these keys.
    commands = {
        "migrate": lambda: migrator.migrate(target_version=opts.target),
        "status": migrator.status,
    }
    commands[opts.command]()


if __name__ == "__main__":
    main()