File size: 2,255 Bytes
9e65b56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
SQLAlchemy 2.0 ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์—”์ง„/์„ธ์…˜ ์„ค์ •.

๋™๊ธฐ ์„ธ์…˜ ๊ธฐ๋ฐ˜์œผ๋กœ ๊ตฌ์„ฑํ•˜๋ฉฐ, FastAPI ์˜์กด์„ฑ ์ฃผ์ž…(get_db)์„ ์ œ๊ณตํ•œ๋‹ค.
๊ธฐ๋ณธ๊ฐ’์€ ๋กœ์ปฌ ๋‹จ์ผ ์‚ฌ์šฉ์ž MVP์— ๋งž์ถฐ GovOn ํ™ˆ ๋””๋ ‰ํ„ฐ๋ฆฌ ์•„๋ž˜ SQLite ํŒŒ์ผ์„ ์‚ฌ์šฉํ•œ๋‹ค.
"""

import logging
import os
from pathlib import Path
from typing import Generator

from sqlalchemy import create_engine
from sqlalchemy.engine import make_url
from sqlalchemy.orm import Session, sessionmaker

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# ์—”์ง„ & ์„ธ์…˜ ํŒฉํ† ๋ฆฌ
# ---------------------------------------------------------------------------

_DEFAULT_GOVON_HOME = Path(os.getenv("GOVON_HOME", Path.home() / ".govon"))
_DEFAULT_DATABASE_URL = f"sqlite:///{_DEFAULT_GOVON_HOME / 'metadata.sqlite3'}"

DATABASE_URL: str = os.getenv("DATABASE_URL", _DEFAULT_DATABASE_URL)

if DATABASE_URL == _DEFAULT_DATABASE_URL:
    logger.warning(
        "DATABASE_URL ํ™˜๊ฒฝ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•„ ๋กœ์ปฌ SQLite ๊ธฐ๋ณธ๊ฐ’์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. "
        "๋ณ„๋„ RDBMS๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋ฉด DATABASE_URL์„ ๋ช…์‹œ์ ์œผ๋กœ ์„ค์ •ํ•˜์„ธ์š”."
    )

engine_kwargs = {
    "echo": os.getenv("SQL_ECHO", "").lower() in ("1", "true"),
}
if DATABASE_URL.startswith("sqlite:///"):
    _DEFAULT_GOVON_HOME.mkdir(parents=True, exist_ok=True)
    engine_kwargs["connect_args"] = {"check_same_thread": False}
else:
    engine_kwargs.update(
        {
            "pool_size": 10,
            "max_overflow": 20,
            "pool_pre_ping": True,
            "pool_recycle": 3600,
        }
    )

engine = create_engine(DATABASE_URL, **engine_kwargs)

SessionLocal = sessionmaker(
    bind=engine,
    autocommit=False,
    autoflush=False,
)


# ---------------------------------------------------------------------------
# FastAPI 의존성 주입
# ---------------------------------------------------------------------------


def get_db() -> Generator[Session, None, None]:
    """Yield a database session for use with FastAPI ``Depends()``.

    A fresh session is opened for each dependency invocation. When the
    generator finishes, whether normally or because the handler raised,
    any uncommitted work is rolled back and the session is closed.

    Example::

        @router.get("/docs")
        def list_docs(db: Session = Depends(get_db)):
            ...
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        try:
            # Discard anything the handler did not commit explicitly.
            db.rollback()
        finally:
            # Nested so the session is closed and its connection released
            # even if rollback() itself raises (e.g. the connection was lost).
            db.close()