Spaces:
Running
Running
| """ | |
| Async migration system for SurrealDB using the official Python client. | |
| Based on patterns from sblpy migration system. | |
| """ | |
| from typing import List | |
| from loguru import logger | |
| from .repository import db_connection, repo_query | |
class AsyncMigration:
    """
    A single SurrealQL migration script that can be applied asynchronously.

    The script text lives in ``self.sql``. ``run`` executes it on a connection
    obtained from ``db_connection()`` and then records the version change via
    ``bump_version()`` (up) or ``lower_version()`` (down).
    """

    def __init__(self, sql: str) -> None:
        """Initialize migration with SQL content.

        Args:
            sql: The SurrealQL script to execute when the migration runs.
        """
        self.sql = sql

    @classmethod
    def from_file(cls, file_path: str) -> "AsyncMigration":
        """Create a migration from a SurrealQL file.

        Each line is stripped of surrounding whitespace. Blank lines and lines
        starting with ``--`` are dropped, and the rest are re-joined with
        newlines.

        Args:
            file_path: Path of the migration file, read as UTF-8.

        Returns:
            A new migration whose ``sql`` is the cleaned file content (an
            empty string if the file held only comments/whitespace).

        Raises:
            OSError: If the file cannot be opened or read.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            raw_content = file.read()

        stripped_lines = (line.strip() for line in raw_content.split("\n"))
        # Join with newlines, not spaces. With spaces, a trailing inline
        # comment (``DEFINE TABLE x; -- note``) or a ``//`` / ``#`` comment
        # line would silently comment out every statement placed after it.
        sql = "\n".join(
            line for line in stripped_lines if line and not line.startswith("--")
        )
        return cls(sql)

    async def run(self, bump: bool = True) -> None:
        """Execute the migration and record the resulting version change.

        Args:
            bump: ``True`` for an "up" migration (calls ``bump_version()``
                afterwards); ``False`` for a "down" migration (calls
                ``lower_version()``).

        Raises:
            Exception: Any error from the query or the version bookkeeping is
                logged and re-raised unchanged. If the SQL itself fails, the
                recorded version is left untouched.
        """
        try:
            # Only execute SQL if there's actual content (not just comments/whitespace)
            if self.sql and self.sql.strip():
                async with db_connection() as connection:
                    # NOTE(review): assumes the client raises on statement
                    # errors rather than returning them in the result — confirm.
                    await connection.query(self.sql)
            # Always bump/lower version even for empty migrations so the
            # recorded version stays aligned with the migration index.
            # The query and the version update are separate round-trips and
            # are not atomic.
            if bump:
                await bump_version()
            else:
                await lower_version()
        except Exception as e:
            logger.error(f"Migration failed: {str(e)}")
            raise
| class AsyncMigrationRunner: | |
| """ | |
| Handles running multiple migrations in sequence. | |
| """ | |
| def __init__( | |
| self, | |
| up_migrations: List[AsyncMigration], | |
| down_migrations: List[AsyncMigration], | |
| ) -> None: | |
| """Initialize runner with migration lists.""" | |
| self.up_migrations = up_migrations | |
| self.down_migrations = down_migrations | |
| async def run_all(self) -> None: | |
| """Run all pending up migrations.""" | |
| current_version = await get_latest_version() | |
| for i in range(current_version, len(self.up_migrations)): | |
| logger.info(f"Running migration {i + 1}") | |
| await self.up_migrations[i].run(bump=True) | |
| async def run_one_up(self) -> None: | |
| """Run one up migration.""" | |
| current_version = await get_latest_version() | |
| if current_version < len(self.up_migrations): | |
| logger.info(f"Running migration {current_version + 1}") | |
| await self.up_migrations[current_version].run(bump=True) | |
| async def run_one_down(self) -> None: | |
| """Run one down migration.""" | |
| current_version = await get_latest_version() | |
| if current_version > 0: | |
| logger.info(f"Rolling back migration {current_version}") | |
| await self.down_migrations[current_version - 1].run(bump=False) | |
class AsyncMigrationManager:
    """
    Main migration manager with async support.

    Loads the numbered ``<n>.surrealql`` / ``<n>_down.surrealql`` migration
    pairs and applies pending ones through an ``AsyncMigrationRunner``.
    """

    # Number of migration pairs shipped with the application. Increase this
    # when adding a new ``<n>.surrealql`` / ``<n>_down.surrealql`` pair.
    MIGRATION_COUNT = 18

    def __init__(
        self,
        migrations_dir: str = "migrations",
        migration_count: int = MIGRATION_COUNT,
    ) -> None:
        """Load all up and down migrations.

        Args:
            migrations_dir: Directory holding the migration files. Paths are
                built as ``f"{migrations_dir}/{n}.surrealql"``, so a relative
                directory resolves against the process working directory.
            migration_count: Number of pairs to load, numbered
                ``1..migration_count``. The default is bound when the class
                is created, so overriding ``MIGRATION_COUNT`` in a subclass
                does not change it.

        Raises:
            OSError: If any migration file cannot be read.
        """
        numbers = range(1, migration_count + 1)
        # Both lists come from the same range so up and down migrations can
        # never drift out of step (previously two hand-maintained lists).
        self.up_migrations = [
            AsyncMigration.from_file(f"{migrations_dir}/{n}.surrealql")
            for n in numbers
        ]
        self.down_migrations = [
            AsyncMigration.from_file(f"{migrations_dir}/{n}_down.surrealql")
            for n in numbers
        ]
        self.runner = AsyncMigrationRunner(
            up_migrations=self.up_migrations,
            down_migrations=self.down_migrations,
        )

    async def get_current_version(self) -> int:
        """Return the highest recorded migration version (0 if none)."""
        return await get_latest_version()

    async def needs_migration(self) -> bool:
        """Return True if fewer migrations are recorded than are loaded."""
        current_version = await self.get_current_version()
        return current_version < len(self.up_migrations)

    async def run_migration_up(self) -> None:
        """Apply every pending up migration, logging versions before and after.

        Raises:
            Exception: Re-raised after logging if any migration fails.
        """
        current_version = await self.get_current_version()
        logger.info(f"Current version before migration: {current_version}")
        # Reuse the version just fetched instead of a second DB round-trip
        # through needs_migration().
        if current_version < len(self.up_migrations):
            try:
                await self.runner.run_all()
                new_version = await self.get_current_version()
                logger.info(f"Migration successful. New version: {new_version}")
            except Exception as e:
                logger.error(f"Migration failed: {str(e)}")
                raise
        else:
            logger.info("Database is already at the latest version")
| # Database version management functions | |
async def get_latest_version() -> int:
    """Return the highest recorded migration version.

    Returns 0 when nothing is recorded or when any error occurs while
    reading or inspecting the version rows (e.g. the migrations table does
    not exist yet).
    """
    try:
        rows = await get_all_versions()
        return max((row["version"] for row in rows), default=0)
    except Exception:
        # Deliberate best-effort: an unreadable history is treated as version 0.
        return 0
async def get_all_versions() -> List[dict]:
    """Fetch every row of ``_sbl_migrations``, ordered by ``version``.

    Any query failure (for example, the table not existing yet) yields an
    empty list instead of an exception.
    """
    statement = "SELECT * FROM _sbl_migrations ORDER BY version;"
    try:
        return await repo_query(statement)
    except Exception:
        return []
async def bump_version() -> None:
    """Record one more applied migration.

    Creates the record ``_sbl_migrations:<v+1>`` with ``version = v+1`` and
    ``applied_at = time::now()``, where ``v`` is the current latest version.
    """
    latest = await get_latest_version()
    target = latest + 1
    statement = (
        f"CREATE _sbl_migrations:{target} "
        f"SET version = {target}, applied_at = time::now();"
    )
    await repo_query(statement)
async def lower_version() -> None:
    """Remove the record for the latest applied migration, if there is one."""
    latest = await get_latest_version()
    if latest <= 0:
        # Nothing recorded: nothing to remove.
        return
    await repo_query(f"DELETE _sbl_migrations:{latest};")