File size: 4,348 Bytes
697c967
 
6a3de9e
697c967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a3de9e
 
 
 
 
 
 
697c967
 
 
 
 
 
 
 
 
 
6a3de9e
 
 
 
 
 
 
697c967
 
 
 
 
6a3de9e
 
 
 
 
697c967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a3de9e
 
697c967
 
6a3de9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Generator
from urllib.parse import parse_qsl, urlencode

from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.ext.asyncio.session import AsyncSession

from config.settings import settings

# --- Database URL normalization -------------------------------------------


def _replace_scheme(url: str, old: str, new: str) -> str:
    """Return *url* with its leading ``old`` scheme prefix replaced by ``new``.

    URLs that do not start with ``old`` are returned unchanged.
    """
    if url.startswith(old):
        return new + url[len(old):]
    return url


def _to_async_url(url: str) -> str:
    """Rewrite a configured database URL so it uses an asyncio driver.

    ``postgres://`` and ``postgresql://`` become ``postgresql+asyncpg://``;
    ``sqlite://`` becomes ``sqlite+aiosqlite://``. URLs that already name a
    driver (or use any other scheme) keep their scheme.

    SQLAlchemy's asyncpg dialect forwards URL query parameters to
    ``asyncpg.connect()`` as keyword arguments, and asyncpg rejects libpq's
    ``sslmode`` (as found in e.g. ``?sslmode=require`` connection strings).
    When ``sslmode`` appears anywhere in the query, the query string is
    dropped. The requested mode is then re-added as asyncpg's own ``ssl``
    parameter, so an explicit SSL requirement is not silently downgraded.
    """
    url = _replace_scheme(url, "postgres://", "postgresql+asyncpg://")
    url = _replace_scheme(url, "postgresql://", "postgresql+asyncpg://")
    url = _replace_scheme(url, "sqlite://", "sqlite+aiosqlite://")

    if url.startswith("postgresql+asyncpg://"):
        base, _, query = url.partition("?")
        params = dict(parse_qsl(query))
        if "sslmode" in params:
            # Other libpq-only parameters that usually accompany sslmode
            # (e.g. channel_binding) would also break asyncpg, so the rest
            # of the query is discarded rather than forwarded.
            url = f"{base}?{urlencode({'ssl': params['sslmode']})}"
    return url


def _to_sync_url(url: str) -> str:
    """Rewrite a configured database URL so it uses the default sync driver.

    The query string is kept intact. The default PostgreSQL driver
    (psycopg2) is libpq-based and understands parameters such as
    ``sslmode`` natively.
    """
    url = _replace_scheme(url, "postgres://", "postgresql://")
    url = _replace_scheme(url, "postgresql+asyncpg://", "postgresql://")
    return _replace_scheme(url, "sqlite+aiosqlite://", "sqlite://")


# --- Engine creation --------------------------------------------------------

# Async URL for the asyncio engine; sync URL for synchronous callers.
db_url = _to_async_url(settings.database_url)
sync_db_url = _to_sync_url(settings.database_url)

if db_url.startswith("postgresql"):
    async_engine = create_async_engine(
        db_url,
        echo=settings.db_echo,  # Set to True for SQL query logging during development
        pool_pre_ping=True,  # Verify connections before use
        pool_recycle=300,  # Recycle connections every 5 minutes
    )
    # Create sync engine for synchronous operations
    sync_engine = create_engine(
        sync_db_url,
        echo=settings.db_echo,
        pool_pre_ping=True,
        pool_recycle=300,
    )
else:  # SQLite (or any other backend without special pool options)
    async_engine = create_async_engine(
        db_url,
        echo=settings.db_echo,  # Set to True for SQL query logging during development
    )
    # FastAPI may run a sync dependency and its endpoint on different
    # threadpool threads. The stdlib sqlite3 driver refuses cross-thread use
    # of a connection unless check_same_thread is disabled.
    sync_connect_args = (
        {"check_same_thread": False} if sync_db_url.startswith("sqlite") else {}
    )
    # Create sync engine for synchronous operations
    sync_engine = create_engine(
        sync_db_url,
        echo=settings.db_echo,
        connect_args=sync_connect_args,
    )

@asynccontextmanager
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Async context manager providing an ``AsyncSession`` bound to ``async_engine``.

    Usage::

        async with get_async_session() as session:
            ...

    The session is closed when the ``async with`` block exits, whether
    normally or via an exception. No commit is issued here; callers must
    commit themselves if they want changes persisted.

    Yields:
        AsyncSession: session bound to ``async_engine``.
    """
    # AsyncSession's own async context manager closes the session on exit,
    # so an extra ``finally: await session.close()`` would be redundant.
    async with AsyncSession(async_engine) as session:
        yield session

async def get_session_dep() -> AsyncGenerator[AsyncSession, None]:
    """
    FastAPI dependency providing an async database session with transaction handling.

    Once the consumer is done with the session (control returns to this
    generator), the session is committed. If an exception is raised while
    the session is in use, or by the commit itself, the transaction is
    rolled back and the exception is re-raised. The session is closed on exit
    by ``AsyncSession``'s async context manager.

    Yields:
        AsyncSession: session bound to ``async_engine``.
    """
    async with AsyncSession(async_engine) as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise

def get_session() -> Generator[Session, None, None]:
    """
    Dependency function to get a synchronous database session.

    The session is bound to ``sync_engine`` and is closed when the generator
    is finalized (``Session``'s context manager closes it on exit).

    Yields:
        Session: SQLModel database session

    Example:
        ```python
        @app.get("/items")
        def get_items(session: Session = Depends(get_session)):
            items = session.exec(select(Item)).all()
            return items
        ```
    """
    with Session(sync_engine) as session:
        yield session


def get_sync_session() -> Generator[Session, None, None]:
    """
    Generator function to get a synchronous database session for use in synchronous contexts like MCP servers.

    The session is bound to ``sync_engine`` and is closed when the generator
    is finalized, including when the consumer raises into it.

    Yields:
        Session: SQLModel synchronous database session
    """
    # Session's context manager calls close() on exit, replacing the manual
    # try/finally.
    with Session(sync_engine) as session:
        yield session


def create_sync_session() -> Session:
    """
    Build a new synchronous session bound to ``sync_engine``.

    Nothing closes this session automatically. The caller owns it and must
    call ``close()`` on it, or use it in a ``with`` block, when finished.

    Returns:
        Session: a fresh SQLModel synchronous database session
    """
    new_session = Session(bind=sync_engine)
    return new_session