File size: 1,371 Bytes
6180059
04a921d
81fb169
04a921d
1d2a7cd
 
04a921d
6180059
 
04a921d
6180059
 
 
04a921d
81fb169
 
 
 
 
 
 
 
 
04a921d
6180059
 
 
81fb169
 
 
04a921d
 
6180059
04a921d
 
 
 
 
6180059
04a921d
 
 
1d2a7cd
6180059
 
1d2a7cd
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import ssl
from typing import AsyncGenerator, Optional, Tuple, Union
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base

def normalize_database_url(url: str) -> Tuple[str, Optional[str]]:
    """Rewrite a Postgres URL for SQLAlchemy's asyncpg dialect.

    Forces the ``postgresql+asyncpg`` driver when the URL uses the bare
    ``postgresql://`` scheme or the legacy ``postgres://`` alias (which
    SQLAlchemy no longer recognizes). Also removes every ``sslmode`` query
    parameter, because asyncpg does not accept libpq's ``sslmode`` keyword.
    All other query parameters are preserved.

    Args:
        url: Raw database URL (may be empty).

    Returns:
        ``(rewritten_url, sslmode)``, where ``sslmode`` is the last ``sslmode``
        value found in the query string, or ``None`` when there is none. If no
        ``sslmode`` is present the URL is returned as-is apart from the
        scheme rewrite.
    """
    for prefix in ("postgresql://", "postgres://"):
        if url.startswith(prefix):
            url = "postgresql+asyncpg://" + url[len(prefix):]
            break

    parts = urlsplit(url)
    params = parse_qsl(parts.query, keep_blank_values=True)
    sslmodes = [value for key, value in params if key == "sslmode"]
    if not sslmodes:
        # Nothing to strip: avoid re-encoding the query string.
        return url, None

    kept = [(key, value) for key, value in params if key != "sslmode"]
    return urlunsplit(parts._replace(query=urlencode(kept))), sslmodes[-1]


def build_ssl_config(sslmode: Optional[str] = None) -> Union[ssl.SSLContext, bool]:
    """Translate a libpq-style ``sslmode`` into a value for asyncpg's ``ssl`` argument.

    - ``"disable"``      -> ``False`` (no TLS).
    - ``"verify-full"``  -> TLS context verifying the certificate chain and hostname.
    - ``"verify-ca"``    -> TLS context verifying the chain but not the hostname.
    - anything else (``None``, ``"allow"``, ``"prefer"``, ``"require"``) ->
      TLS context with certificate verification turned off. This matches this
      module's historical default, so existing deployments keep working.

    Args:
        sslmode: The ``sslmode`` value extracted from the URL, or ``None``.

    Returns:
        An ``ssl.SSLContext``, or ``False`` to disable TLS.
    """
    if sslmode == "disable":
        return False

    context = ssl.create_default_context()
    if sslmode == "verify-full":
        return context

    # check_hostname must be cleared before verify_mode can be lowered.
    context.check_hostname = False
    if sslmode != "verify-ca":
        # NOTE(security): no certificate verification here, so the connection
        # is open to man-in-the-middle attacks. Use sslmode=verify-full where
        # the server's CA is available.
        context.verify_mode = ssl.CERT_NONE
    return context


# Normalized asyncpg URL plus the sslmode (if any) that was stripped from it.
DATABASE_URL, DATABASE_SSLMODE = normalize_database_url(os.getenv("DATABASE_URL", ""))

# Value handed to asyncpg as ``connect_args["ssl"]``: an SSLContext, or False
# when the URL requested ``sslmode=disable``.
ssl_context = build_ssl_config(DATABASE_SSLMODE)

# Engine-level SQL logging (SQLAlchemy ``echo``). Stays on by default, as
# before, but can be switched off with DATABASE_ECHO=0/false/no/off (e.g. in
# production), since echo logs every SQL statement the engine issues.
DATABASE_ECHO = os.getenv("DATABASE_ECHO", "true").strip().lower() not in {"0", "false", "no", "off"}

engine = create_async_engine(
    DATABASE_URL,
    echo=DATABASE_ECHO,
    future=True,
    connect_args={
        # Passed through to asyncpg's connect(): an SSLContext or False.
        "ssl": ssl_context,
    },
)

# Factory for AsyncSession objects bound to ``engine``. expire_on_commit=False
# keeps ORM instances' loaded attributes after commit instead of expiring
# them, so they stay readable without another round-trip to the database.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Declarative base class. ORM model classes are expected to subclass it so
# their tables are registered on ``Base.metadata``, which init_db() uses to
# create the tables.
Base = declarative_base()


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one database session and close it once the consumer is done.

    The session is created by ``AsyncSessionLocal``. It is closed when the
    ``async with`` block exits, which happens after the consumer resumes or
    closes this generator.
    """
    db_session = AsyncSessionLocal()
    async with db_session:
        yield db_session


async def init_db() -> None:
    """Create the tables registered on ``Base.metadata``.

    ``MetaData.create_all`` is synchronous, so it runs through
    ``AsyncConnection.run_sync``. It executes inside the transaction opened
    by ``engine.begin()``, which commits when the block exits without error.
    """
    metadata = Base.metadata
    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)