"""Database engine, session factory, and schema bootstrap helpers."""

import os
from collections.abc import Iterator

from dotenv import load_dotenv
from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import make_url
from sqlalchemy.orm import Session, sessionmaker

from models import Base

load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./stocking.db")

# Some hosting providers hand out the legacy "postgres://" scheme. SQLAlchemy
# 1.4+ no longer accepts "postgres" as a dialect name, so rewrite it to the
# canonical "postgresql://" form. Backend detection and create_engine() then
# both work.
_LEGACY_PG_SCHEME = "postgres://"
if DATABASE_URL.startswith(_LEGACY_PG_SCHEME):
    DATABASE_URL = "postgresql://" + DATABASE_URL[len(_LEGACY_PG_SCHEME):]

_url = make_url(DATABASE_URL)
# get_backend_name() strips any "+driver" suffix (e.g. "postgresql+psycopg2").
_backend = _url.get_backend_name()
_is_postgres = _backend == "postgresql"
_is_sqlite = _backend == "sqlite"

if _is_postgres:
    # Require SSL by default, but honour an explicit ?sslmode=... in the URL.
    # SQLAlchemy applies connect_args after (and on top of) the URL query
    # parameters, so unconditionally passing sslmode here would override it.
    connect_args = {} if "sslmode" in _url.query else {"sslmode": "require"}
elif _is_sqlite:
    # Allow a SQLite connection to be used from a thread other than the one
    # that opened it (sqlite3 refuses this by default).
    connect_args = {"check_same_thread": False}
else:
    connect_args = {}

engine = create_engine(DATABASE_URL, connect_args=connect_args)

# SQLite only: switch to WAL journaling (not applicable to other backends).
if _is_sqlite:
    @event.listens_for(engine, "connect")
    def set_wal_mode(dbapi_connection, connection_record):
        """Enable WAL mode on each new SQLite connection.

        WAL lets the scheduler's writes and the API's reads proceed
        concurrently.
        """
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.close()

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def init_db() -> None:
    """Create all tables and apply lightweight schema migrations.

    Steps:
      1. (SQLite only) If an existing ``watchlist`` table lacks a ``user_id``
         column, drop it. This is destructive: its rows are discarded.
      2. Run ``Base.metadata.create_all`` to create any missing tables.
      3. Create the ``annual_data`` indexes with ``IF NOT EXISTS``, so
         repeated calls are harmless.
    """
    if _is_sqlite:
        # PRAGMA table_info is SQLite-specific. It returns no rows when the
        # table does not exist, so a fresh database is left untouched.
        with engine.begin() as conn:
            cols = [
                row[1]  # column 1 of PRAGMA table_info is the column name
                for row in conn.execute(text("PRAGMA table_info(watchlist)"))
            ]
            if cols and "user_id" not in cols:
                # Outdated schema: drop it so create_all() below can recreate
                # the table from the current model definitions.
                conn.execute(text("DROP TABLE watchlist"))

    Base.metadata.create_all(bind=engine)

    # create_all() does not add indexes to tables that already exist, so
    # create the annual_data indexes explicitly.
    with engine.begin() as conn:
        conn.execute(text(
            "CREATE INDEX IF NOT EXISTS ix_annual_data_ticker_year_is_estimate "
            "ON annual_data (ticker, year, is_estimate)"
        ))
        conn.execute(text(
            "CREATE INDEX IF NOT EXISTS ix_annual_data_ticker_is_estimate "
            "ON annual_data (ticker, is_estimate)"
        ))


def get_db() -> Iterator[Session]:
    """Yield a new session and always close it once the caller is done."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()