File size: 2,126 Bytes
50c20bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
DBService Database Initialization

Provides utilities for creating and managing database tables.
"""

import logging
from sqlalchemy.ext.asyncio import AsyncEngine

from services.db_service.config import DBServiceConfig

logger = logging.getLogger(__name__)


async def init_database(engine: AsyncEngine) -> None:
    """
    Initialize database tables based on registered models.

    Runs ``metadata.create_all`` of the registered declarative base
    (``DBServiceConfig.db_base``) inside one transaction opened with
    ``engine.begin()``. With SQLAlchemy's default ``checkfirst=True``,
    tables that already exist are left alone.

    Args:
        engine: Async SQLAlchemy engine to create the tables on.

    Raises:
        RuntimeError: If no ``db_base`` was passed to
            ``DBServiceConfig.register()``.
    """
    DBServiceConfig.assert_registered()

    # Read the registered base once so the check and the use agree.
    db_base = DBServiceConfig.db_base
    if not db_base:
        raise RuntimeError(
            "No database base registered! "
            "Pass db_base parameter to DBServiceConfig.register()"
        )

    logger.info("Creating database tables...")

    # MetaData.create_all is a synchronous API; run_sync hands it a sync
    # Connection bound to this async transaction.
    async with engine.begin() as conn:
        await conn.run_sync(db_base.metadata.create_all)

    # NOTE(review): create_all covers every table on db_base.metadata; the
    # number logged is the count of registered models, which presumably map
    # 1:1 to those tables — confirm.
    model_count = len(DBServiceConfig.all_models)
    logger.info("✅ Database initialized with %d models", model_count)


async def drop_database(engine: AsyncEngine) -> None:
    """
    Drop all database tables. WARNING: Deletes all data!

    Runs ``metadata.drop_all`` of the registered declarative base
    (``DBServiceConfig.db_base``) inside one transaction opened with
    ``engine.begin()``.

    Args:
        engine: Async SQLAlchemy engine whose tables should be dropped.

    Raises:
        RuntimeError: If no ``db_base`` was passed to
            ``DBServiceConfig.register()``.
    """
    DBServiceConfig.assert_registered()

    # Read the registered base once so the check and the use agree.
    db_base = DBServiceConfig.db_base
    if not db_base:
        # Same guidance as init_database so the failure is actionable.
        raise RuntimeError(
            "No database base registered! "
            "Pass db_base parameter to DBServiceConfig.register()"
        )

    logger.warning("⚠️  Dropping all database tables...")

    # MetaData.drop_all is synchronous; run_sync adapts it to the async
    # connection.
    async with engine.begin() as conn:
        await conn.run_sync(db_base.metadata.drop_all)

    logger.info("✅ All tables dropped")


async def reset_database(engine: AsyncEngine) -> None:
    """
    Reset database (drop + create). WARNING: Deletes all data!

    Calls ``drop_database`` and then ``init_database`` on the same engine.
    The two steps run in separate transactions, and any exception from
    either step propagates unchanged.

    Args:
        engine: Async SQLAlchemy engine to reset.
    """
    await drop_database(engine)
    await init_database(engine)
    logger.info("✅ Database reset complete")


def get_registered_models() -> list:
    """
    Get list of all registered models.

    Returns:
        A new list holding the models from ``DBServiceConfig.all_models``,
        in registration order. It is a shallow copy, so callers may modify
        it without changing the shared registration state.
    """
    DBServiceConfig.assert_registered()
    # Copy so callers can't accidentally mutate the global registry.
    return list(DBServiceConfig.all_models)


def get_model_by_name(model_name: str):
    """
    Look up a registered model class by its class name.

    Scans ``DBServiceConfig.all_models`` in order and returns the first
    model whose ``__name__`` equals ``model_name`` (case-sensitive).

    Args:
        model_name: Class name to search for.

    Returns:
        The matching model class, or ``None`` if no registered model has
        that name.
    """
    DBServiceConfig.assert_registered()
    matches = (
        candidate
        for candidate in DBServiceConfig.all_models
        if candidate.__name__ == model_name
    )
    return next(matches, None)