import os

from dotenv import load_dotenv
from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import make_url
from sqlalchemy.orm import sessionmaker

from models import Base
|
|
load_dotenv()

# Treat an empty DATABASE_URL (e.g. "DATABASE_URL=" in .env) the same as an
# unset one; os.getenv's default only applies when the variable is missing.
DATABASE_URL = os.getenv("DATABASE_URL") or "sqlite:///./stocking.db"

# Accept the legacy "postgres://" scheme that some hosts still hand out.
# SQLAlchemy 1.4+ only recognises "postgresql://", and without this rewrite
# the URL would also be misclassified as non-PostgreSQL below.
_LEGACY_PG_SCHEME = "postgres://"
if DATABASE_URL.startswith(_LEGACY_PG_SCHEME):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len(_LEGACY_PG_SCHEME):]

# Any URL that is not PostgreSQL is handled with SQLite-specific settings.
_is_postgres = DATABASE_URL.startswith("postgresql")

if _is_postgres:
    # Require TLS by default, but honour an explicit ?sslmode=... in the URL:
    # SQLAlchemy applies connect_args on top of the URL's query parameters,
    # so always passing sslmode would silently override the URL's choice.
    _url_query = make_url(DATABASE_URL).query
    connect_args = {} if "sslmode" in _url_query else {"sslmode": "require"}
else:
    # Let pooled SQLite connections be used from threads other than the one
    # that opened them (the sqlite3 driver refuses this by default).
    connect_args = {"check_same_thread": False}

engine = create_engine(DATABASE_URL, connect_args=connect_args)
|
|
| |
if not _is_postgres:
    @event.listens_for(engine, "connect")
    def set_wal_mode(dbapi_connection, connection_record):
        """Put every newly opened SQLite connection into WAL journal mode.

        WAL mode lets scheduler writes and API reads proceed concurrently.

        Args:
            dbapi_connection: The raw DBAPI connection just opened by the pool.
            connection_record: Pool bookkeeping record (unused here).
        """
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA journal_mode=WAL")
        finally:
            # Close the cursor even if the PRAGMA fails, so it is not leaked.
            cursor.close()
|
|
# Factory for ORM sessions bound to the module engine: changes are only
# persisted by an explicit commit, and queries do not trigger an auto-flush.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
# Composite indexes on annual_data, created idempotently after create_all().
# They run unconditionally, so they use syntax accepted by both SQLite and
# PostgreSQL (CREATE INDEX IF NOT EXISTS).
_ANNUAL_DATA_INDEX_DDL = (
    "CREATE INDEX IF NOT EXISTS ix_annual_data_ticker_year_is_estimate "
    "ON annual_data (ticker, year, is_estimate)",
    "CREATE INDEX IF NOT EXISTS ix_annual_data_ticker_is_estimate "
    "ON annual_data (ticker, is_estimate)",
)


def _drop_legacy_sqlite_watchlist():
    """Drop a SQLite ``watchlist`` table that predates the ``user_id`` column.

    This lets ``Base.metadata.create_all`` recreate the table with the current
    schema. WARNING: destructive -- every row in the old table is discarded.
    Does nothing when the table does not exist (``PRAGMA table_info`` then
    returns no rows) or already has ``user_id``.
    """
    with engine.begin() as conn:
        # Column 1 of each PRAGMA table_info row is the column name.
        cols = [
            row[1]
            for row in conn.execute(text("PRAGMA table_info(watchlist)")).fetchall()
        ]
        if cols and "user_id" not in cols:
            conn.execute(text("DROP TABLE watchlist"))


def init_db():
    """Create all ORM tables and the supporting annual_data indexes.

    On SQLite, an outdated ``watchlist`` table (no ``user_id`` column) is
    dropped first so that it is rebuilt with the current schema. Safe to call
    repeatedly: table and index creation are both idempotent.
    """
    if not _is_postgres:
        _drop_legacy_sqlite_watchlist()

    Base.metadata.create_all(bind=engine)

    # engine.begin() commits on success and rolls back if any statement fails.
    with engine.begin() as conn:
        for ddl in _ANNUAL_DATA_INDEX_DDL:
            conn.execute(text(ddl))
|
|
|
|
def get_db():
    """Yield a fresh ORM session and close it when the generator is finished.

    Cleanup also runs if an exception is thrown into the generator or it is
    closed early, because the session is managed by a ``with`` block.
    """
    with SessionLocal() as session:
        yield session
|
|