""" Async migration system for SurrealDB using the official Python client. Based on patterns from sblpy migration system. """ from typing import List from loguru import logger from .repository import db_connection, repo_query class AsyncMigration: """ Handles individual migration operations with async support. """ def __init__(self, sql: str) -> None: """Initialize migration with SQL content.""" self.sql = sql @classmethod def from_file(cls, file_path: str) -> "AsyncMigration": """Create migration from SQL file.""" with open(file_path, "r", encoding="utf-8") as file: raw_content = file.read() # Clean up SQL content lines = [] for line in raw_content.split("\n"): line = line.strip() if line and not line.startswith("--"): lines.append(line) sql = " ".join(lines) return cls(sql) async def run(self, bump: bool = True) -> None: """Run the migration.""" try: # Only execute SQL if there's actual content (not just comments/whitespace) if self.sql and self.sql.strip(): async with db_connection() as connection: await connection.query(self.sql) # Always bump/lower version even for empty migrations if bump: await bump_version() else: await lower_version() except Exception as e: logger.error(f"Migration failed: {str(e)}") raise class AsyncMigrationRunner: """ Handles running multiple migrations in sequence. """ def __init__( self, up_migrations: List[AsyncMigration], down_migrations: List[AsyncMigration], ) -> None: """Initialize runner with migration lists.""" self.up_migrations = up_migrations self.down_migrations = down_migrations async def run_all(self) -> None: """Run all pending up migrations.""" current_version = await get_latest_version() for i in range(current_version, len(self.up_migrations)): logger.info(f"Running migration {i + 1}") await self.up_migrations[i].run(bump=True) async def run_one_up(self) -> None: """Run one up migration.""" current_version = await get_latest_version() if current_version < len(self.up_migrations): logger.info(f"Running migration {current_version + 1}") await self.up_migrations[current_version].run(bump=True) async def run_one_down(self) -> None: """Run one down migration.""" current_version = await get_latest_version() if current_version > 0: logger.info(f"Rolling back migration {current_version}") await self.down_migrations[current_version - 1].run(bump=False) class AsyncMigrationManager: """ Main migration manager with async support. """ def __init__(self): """Initialize migration manager.""" self.up_migrations = [ AsyncMigration.from_file("migrations/1.surrealql"), AsyncMigration.from_file("migrations/2.surrealql"), AsyncMigration.from_file("migrations/3.surrealql"), AsyncMigration.from_file("migrations/4.surrealql"), AsyncMigration.from_file("migrations/5.surrealql"), AsyncMigration.from_file("migrations/6.surrealql"), AsyncMigration.from_file("migrations/7.surrealql"), AsyncMigration.from_file("migrations/8.surrealql"), AsyncMigration.from_file("migrations/9.surrealql"), AsyncMigration.from_file("migrations/10.surrealql"), AsyncMigration.from_file("migrations/11.surrealql"), AsyncMigration.from_file("migrations/12.surrealql"), AsyncMigration.from_file("migrations/13.surrealql"), AsyncMigration.from_file("migrations/14.surrealql"), AsyncMigration.from_file("migrations/15.surrealql"), AsyncMigration.from_file("migrations/16.surrealql"), AsyncMigration.from_file("migrations/17.surrealql"), AsyncMigration.from_file("migrations/18.surrealql"), ] self.down_migrations = [ AsyncMigration.from_file("migrations/1_down.surrealql"), AsyncMigration.from_file("migrations/2_down.surrealql"), AsyncMigration.from_file("migrations/3_down.surrealql"), AsyncMigration.from_file("migrations/4_down.surrealql"), AsyncMigration.from_file("migrations/5_down.surrealql"), AsyncMigration.from_file("migrations/6_down.surrealql"), AsyncMigration.from_file("migrations/7_down.surrealql"), AsyncMigration.from_file("migrations/8_down.surrealql"), AsyncMigration.from_file("migrations/9_down.surrealql"), AsyncMigration.from_file("migrations/10_down.surrealql"), AsyncMigration.from_file("migrations/11_down.surrealql"), AsyncMigration.from_file("migrations/12_down.surrealql"), AsyncMigration.from_file("migrations/13_down.surrealql"), AsyncMigration.from_file("migrations/14_down.surrealql"), AsyncMigration.from_file("migrations/15_down.surrealql"), AsyncMigration.from_file("migrations/16_down.surrealql"), AsyncMigration.from_file("migrations/17_down.surrealql"), AsyncMigration.from_file("migrations/18_down.surrealql"), ] self.runner = AsyncMigrationRunner( up_migrations=self.up_migrations, down_migrations=self.down_migrations, ) async def get_current_version(self) -> int: """Get current database version.""" return await get_latest_version() async def needs_migration(self) -> bool: """Check if migration is needed.""" current_version = await self.get_current_version() return current_version < len(self.up_migrations) async def run_migration_up(self): """Run all pending migrations.""" current_version = await self.get_current_version() logger.info(f"Current version before migration: {current_version}") if await self.needs_migration(): try: await self.runner.run_all() new_version = await self.get_current_version() logger.info(f"Migration successful. New version: {new_version}") except Exception as e: logger.error(f"Migration failed: {str(e)}") raise else: logger.info("Database is already at the latest version") # Database version management functions async def get_latest_version() -> int: """Get the latest version from the migrations table.""" try: versions = await get_all_versions() if not versions: return 0 return max(version["version"] for version in versions) except Exception: # If migrations table doesn't exist, we're at version 0 return 0 async def get_all_versions() -> List[dict]: """Get all versions from the migrations table.""" try: result = await repo_query("SELECT * FROM _sbl_migrations ORDER BY version;") return result except Exception: # If table doesn't exist, return empty list return [] async def bump_version() -> None: """Bump the version by adding a new entry to migrations table.""" current_version = await get_latest_version() new_version = current_version + 1 await repo_query( f"CREATE _sbl_migrations:{new_version} SET version = {new_version}, applied_at = time::now();", ) async def lower_version() -> None: """Lower the version by removing the latest entry from migrations table.""" current_version = await get_latest_version() if current_version > 0: await repo_query(f"DELETE _sbl_migrations:{current_version};")