File size: 6,161 Bytes
010f0b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
from pathlib import Path
from urllib.parse import quote_plus

from dotenv import load_dotenv
from sqlalchemy import create_engine

_env_dir = Path(__file__).resolve().parent
# Load the env files that sit next to this module, in priority order.
# Neither call overrides a variable that is already set, so ".env.local"
# (local development) wins over ".env" (fallback for Docker / production).
for _env_file in (".env.local", ".env"):
    load_dotenv(_env_dir / _env_file, override=False)
from sqlalchemy.orm import sessionmaker, declarative_base

# Optional: a complete database URL takes precedence over the PG_* settings,
# e.g. postgresql+psycopg2://user:pass@host:5432/dbname
_database_url = os.environ.get("DATABASE_URL", "").strip()

if _database_url:
    SQLALCHEMY_DATABASE_URL = _database_url
else:
    # Individual connection settings, each with a local-development default.
    PG_HOST = os.environ.get("PG_HOST", "localhost")
    PG_PORT = os.environ.get("PG_PORT", "5432")
    PG_USER = os.environ.get("PG_USER", "postgres")
    PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")
    PG_DB = os.environ.get("PG_DB", "vector_match")
    # Percent-encode BOTH credentials: characters such as "@", ":" or "/" in
    # the user name or the password would otherwise be read as URL delimiters.
    _user = quote_plus(PG_USER)
    _pw = quote_plus(PG_PASSWORD)
    SQLALCHEMY_DATABASE_URL = (
        f"postgresql+psycopg2://{_user}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}"
    )

# Schema used for the search path and for schema-qualified table names.
PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")

# Module-wide engine; every session and migration helper below goes through it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    # Pool sizing: 20 pooled connections plus up to 10 overflow connections.
    pool_size=20,
    max_overflow=10,
    # Wait at most 60 s for a free connection before giving up.
    pool_timeout=60,
    # Replace connections older than 180 s and ping each one on checkout.
    pool_recycle=180,
    pool_pre_ping=True,
)

# On every new connection, switch the search path to the configured schema (PG_SCHEMA).
from sqlalchemy import event

@event.listens_for(engine, "connect")
def _set_search_path(dbapi_conn, connection_record):
    """Point each new DBAPI connection at ``PG_SCHEMA`` (then ``public``).

    Args:
        dbapi_conn: The raw DBAPI connection that was just opened.
        connection_record: Pool record for the connection (unused).

    The statement is issued with autocommit temporarily enabled. A plain
    ``SET`` executed inside the driver's implicit transaction is undone when
    that transaction is rolled back, which would silently reset the search
    path on a pooled connection after the first rollback.
    """
    # search_path is PostgreSQL-specific; the other helpers in this module
    # branch on the dialect the same way.
    if engine.dialect.name != "postgresql":
        return

    previous_autocommit = dbapi_conn.autocommit
    dbapi_conn.autocommit = True
    try:
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute(f"SET search_path TO {PG_SCHEMA}, public")
        finally:
            # Release the cursor even when the SET statement fails.
            cursor.close()
    finally:
        # Hand the connection back in the mode the pool expects.
        dbapi_conn.autocommit = previous_autocommit

# Session factory bound to the shared engine; sessions neither autoflush nor
# autocommit, so callers flush and commit explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Declarative base class created by this module.
Base = declarative_base()

# "data" directory next to this module; created at import time if missing.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Tables visited by the column-migration helper below.
_SCHEMA_TABLES = (
    "vector_match_task",
    "vector_dataset",
    "vector_data_row",
    "vector_embedding",
    "match_result",
)


def get_db():
    """Yield a new session and always close it afterwards.

    If an exception is raised while the caller is using the session, the
    session is rolled back and the exception is re-raised. Nothing is
    committed here; committing is left to the caller.
    """
    session = SessionLocal()
    try:
        yield session
    except Exception:
        # Discard pending changes before propagating the error.
        session.rollback()
        raise
    finally:
        session.close()


def _table_sql_name(table: str) -> str:
    """Return the name to use for *table* in raw SQL.

    On PostgreSQL this is the double-quoted, schema-qualified form
    ``"<PG_SCHEMA>"."<table>"``; on any other dialect the bare table name.
    """
    on_postgres = engine.dialect.name == "postgresql"
    return f'"{PG_SCHEMA}"."{table}"' if on_postgres else table


def _column_names_conn(conn, table: str) -> set:
    """Return the set of column names of *table*, reflected through *conn*.

    The lookup shares the caller's connection and therefore its transaction.
    Original author's note: inspecting via ``inspect(engine)`` would open a
    second connection to read the catalog while this transaction holds an
    ALTER lock, and on PostgreSQL the two sessions then block each other.
    """
    from sqlalchemy import inspect

    # Only PostgreSQL gets an explicit schema; other dialects use the default.
    target_schema = None
    if engine.dialect.name == "postgresql":
        target_schema = PG_SCHEMA

    names = set()
    for column in inspect(conn).get_columns(table, schema=target_schema):
        names.add(column["name"])
    return names


def _ensure_is_archived_column():
    """Add ``is_archived`` to ``vector_match_task`` when an older database lacks it."""
    from sqlalchemy import inspect, text

    inspector = inspect(engine)
    target_schema = None
    if engine.dialect.name == "postgresql":
        target_schema = PG_SCHEMA

    try:
        columns = inspector.get_columns("vector_match_task", schema=target_schema)
    except Exception:
        # Best effort: if the table cannot be inspected, skip the migration.
        return

    existing = {column["name"] for column in columns}
    if "is_archived" in existing:
        return

    # _table_sql_name() already yields the bare table name off PostgreSQL, so
    # a single statement template covers every dialect.
    qualified = _table_sql_name("vector_match_task")
    statement = f"ALTER TABLE {qualified} ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0"
    with engine.begin() as conn:
        conn.execute(text(statement))


def _ensure_time_is_delete_columns():
    """Bring timestamp and soft-delete columns in line with the current naming.

    Task table: ``created_at`` -> ``created_time``, ``updated_at`` ->
    ``updated_time``, ``is_deleted`` -> ``is_delete``; a legacy ``deleted_at``
    column is folded into ``is_delete`` and then dropped.
    Every table in ``_SCHEMA_TABLES``: ``created_at`` -> ``created_time``, and
    ``is_delete`` / ``updated_time`` are added when missing.
    """
    from sqlalchemy import text

    def _rename_if_needed(conn, qualified, existing, old, new):
        # Rename only when the legacy column exists and the new name is free.
        if old in existing and new not in existing:
            conn.execute(text(f"ALTER TABLE {qualified} RENAME COLUMN {old} TO {new}"))

    task_table = _table_sql_name("vector_match_task")

    # Step 1: task-table-specific renames and the deleted_at migration.
    with engine.begin() as conn:
        existing = _column_names_conn(conn, "vector_match_task")
        for old, new in (
            ("created_at", "created_time"),
            ("updated_at", "updated_time"),
            ("is_deleted", "is_delete"),
        ):
            _rename_if_needed(conn, task_table, existing, old, new)

        # Re-read the columns so the renames above are visible.
        existing = _column_names_conn(conn, "vector_match_task")
        if "deleted_at" in existing:
            if "is_delete" not in existing:
                conn.execute(
                    text(f"ALTER TABLE {task_table} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
                )
            # Carry the soft-delete marker over before the old column goes away.
            conn.execute(text(f"UPDATE {task_table} SET is_delete = 1 WHERE deleted_at IS NOT NULL"))
            conn.execute(text(f"ALTER TABLE {task_table} DROP COLUMN deleted_at"))

    # Step 2: columns every table must have; one transaction per table.
    for table in _SCHEMA_TABLES:
        qualified = _table_sql_name(table)
        with engine.begin() as conn:
            _rename_if_needed(
                conn, qualified, _column_names_conn(conn, table), "created_at", "created_time"
            )

            if "is_delete" not in _column_names_conn(conn, table):
                conn.execute(
                    text(f"ALTER TABLE {qualified} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
                )

            existing = _column_names_conn(conn, table)
            if "updated_time" in existing:
                continue
            conn.execute(text(f"ALTER TABLE {qualified} ADD COLUMN updated_time TIMESTAMP NULL"))
            if "created_time" in existing:
                # Seed the new column from the creation timestamp.
                conn.execute(
                    text(
                        f"UPDATE {qualified} SET updated_time = created_time WHERE updated_time IS NULL"
                    )
                )


def init_db():
    """Create the schema and tables, then apply the in-place column migrations.

    Steps, in order: create ``PG_SCHEMA`` if it does not exist (PostgreSQL
    only), create all tables registered on the models' metadata, then run the
    ``is_archived`` and timestamp / ``is_delete`` column migrations.
    """
    from models import Base as ModelBase  # noqa: F401
    from sqlalchemy import text

    # CREATE SCHEMA is only issued on PostgreSQL, matching the dialect checks
    # in the other helpers of this module.
    if engine.dialect.name == "postgresql":
        # engine.begin() commits on success, like the migration helpers, and
        # does not rely on Connection.commit() being available.
        with engine.begin() as conn:
            conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {PG_SCHEMA}"))

    ModelBase.metadata.create_all(bind=engine)
    _ensure_is_archived_column()
    _ensure_time_is_delete_columns()