# vector-match-api / database.py
# teryryy's picture
# Upload 13 files
# ba016aa verified
import os
from pathlib import Path
from urllib.parse import quote_plus
from dotenv import load_dotenv
from sqlalchemy import create_engine
# Env files live next to this module. ".env.local" (local development) is
# loaded before ".env" (Docker / production fallback). Because override=False,
# neither file replaces a variable already present in the process environment
# or set by the file loaded before it.
_env_dir = Path(__file__).resolve().parent
for _env_file in (".env.local", ".env"):
    load_dotenv(_env_dir / _env_file, override=False)
from sqlalchemy.orm import sessionmaker, declarative_base
# Optional complete URL (highest priority), e.g.
# postgresql+psycopg2://user:pass@host:5432/dbname
# Without it, the URL is assembled from the individual PG_* variables.
from urllib.parse import quote

_database_url = os.environ.get("DATABASE_URL", "").strip()
if _database_url:
    SQLALCHEMY_DATABASE_URL = _database_url
else:
    PG_HOST = os.environ.get("PG_HOST", "localhost")
    PG_PORT = os.environ.get("PG_PORT", "5432")
    PG_USER = os.environ.get("PG_USER", "postgres")
    PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")
    PG_DB = os.environ.get("PG_DB", "vector_match")
    # Percent-encode every reserved character in BOTH credentials.
    # - The user name used to be inserted raw, so "@", ":" or "/" in it broke
    #   URL parsing.
    # - quote_plus() encodes a space as "+", which only round-trips through an
    #   unquote_plus-based parser.
    # quote(..., safe="") emits %20 / %2B / %40 ..., which decodes correctly
    # with either unquote or unquote_plus.
    _user = quote(PG_USER, safe="")
    _pw = quote(PG_PASSWORD, safe="")
    SQLALCHEMY_DATABASE_URL = (
        f"postgresql+psycopg2://{_user}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}"
    )
# Schema that holds this service's tables; used for search_path and by the
# startup migrations below.
PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")

# Process-wide engine with an explicitly sized connection pool.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_size=20,        # connections kept open in the pool
    max_overflow=10,     # extra connections allowed beyond pool_size
    pool_timeout=60,     # seconds to wait for a free connection
    pool_recycle=180,    # recycle connections older than 180 seconds
    pool_pre_ping=True,  # liveness check on checkout
)
# Point every new connection's search_path at PG_SCHEMA (then public).
from sqlalchemy import event


@event.listens_for(engine, "connect", insert=True)
def _set_search_path(dbapi_conn, connection_record):
    """Set ``search_path`` to PG_SCHEMA, public on a freshly opened connection.

    The SET is issued with DBAPI autocommit temporarily enabled so it runs
    outside any transaction. Issued inside the driver's implicit transaction,
    as before, it was reverted by the next ROLLBACK (e.g. the pool's
    reset-on-return after a read-only session, or the dialect's first-connect
    rollback). The pooled connection then silently fell back to the default
    search_path.

    ``insert=True`` makes this listener run before SQLAlchemy's own connect
    handlers, while the connection is still idle, so toggling autocommit is
    allowed. ``SET search_path`` is PostgreSQL-only, so other dialects are
    skipped.

    Args:
        dbapi_conn: the raw DBAPI connection that was just created.
        connection_record: the pool's record for this connection (unused).
    """
    if engine.dialect.name != "postgresql":
        return
    # Quote the identifier the same way _table_sql_name does ("schema"."table").
    quoted_schema = '"' + PG_SCHEMA.replace('"', '""') + '"'
    previous_autocommit = dbapi_conn.autocommit
    dbapi_conn.autocommit = True
    try:
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute(f"SET SESSION search_path TO {quoted_schema}, public")
        finally:
            cursor.close()
    finally:
        dbapi_conn.autocommit = previous_autocommit
# Session factory: no autocommit, no autoflush; bound to the engine above.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class.
Base = declarative_base()

# "data" directory beside this module, created at import time if missing.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Tables whose columns _ensure_time_is_delete_columns() normalizes.
_SCHEMA_TABLES = (
    "vector_match_task", "vector_dataset", "vector_data_row",
    "vector_embedding", "match_result",
)
def get_db():
    """Yield one session from SessionLocal and always close it afterwards.

    If the consumer raises an ``Exception`` while the session is out, the
    session is rolled back and the exception is re-raised.
    """
    # Session.__exit__ calls close(), replacing the explicit finally block.
    with SessionLocal() as session:
        try:
            yield session
        except Exception:
            session.rollback()
            raise
def _table_sql_name(table: str) -> str:
    """Return *table* as it should appear in hand-written SQL.

    On PostgreSQL the name is schema-qualified and double-quoted
    (``"<PG_SCHEMA>"."<table>"``). On any other dialect the bare name is
    returned unchanged.
    """
    if engine.dialect.name != "postgresql":
        return table
    return '"{}"."{}"'.format(PG_SCHEMA, table)
def _column_names_conn(conn, table: str) -> set:
    """Return the set of column names of *table*, read through *conn*.

    The lookup shares the caller's connection (and so its transaction) rather
    than calling ``inspect(engine)``, which opens a new connection to query the
    catalog. The original author's rationale: doing that inside a transaction
    that holds an ALTER lock self-deadlocks on PostgreSQL. Session A holds the
    lock and waits for B's metadata query, while B waits for A to release the
    lock.

    On PostgreSQL the table is looked up in PG_SCHEMA; elsewhere in the
    default schema.
    """
    from sqlalchemy import inspect

    schema = None
    if engine.dialect.name == "postgresql":
        schema = PG_SCHEMA
    columns = inspect(conn).get_columns(table, schema=schema)
    return set(col["name"] for col in columns)
def _ensure_is_archived_column():
    """Add ``is_archived`` to ``vector_match_task`` on legacy databases.

    Adds ``is_archived INTEGER NOT NULL DEFAULT 0`` when the column is
    missing. Does nothing if the column already exists or the table does not
    exist.

    The existence check and the ALTER share one transaction and connection via
    ``_column_names_conn``, matching ``_ensure_time_is_delete_columns``.
    Previously the check ran on a separate ``inspect(engine)`` connection
    before the ALTER's transaction began.

    Raises:
        Any database error other than a missing table. These were previously
        swallowed by a blanket ``except Exception``.
    """
    from sqlalchemy import text
    from sqlalchemy.exc import NoSuchTableError

    # _table_sql_name already yields the bare name on non-PostgreSQL dialects,
    # so a single DDL string covers every dialect.
    ft = _table_sql_name("vector_match_task")
    with engine.begin() as conn:
        try:
            cols = _column_names_conn(conn, "vector_match_task")
        except NoSuchTableError:
            return
        if "is_archived" in cols:
            return
        conn.execute(
            text(f"ALTER TABLE {ft} ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0")
        )
def _migrate_task_legacy_columns(conn) -> None:
    """Rename/fold the legacy columns specific to ``vector_match_task``."""
    from sqlalchemy import text

    ft_task = _table_sql_name("vector_match_task")
    cols = _column_names_conn(conn, "vector_match_task")
    # All three checks use the same column snapshot; each renames only if the
    # target name is not already present.
    for old, new in (
        ("created_at", "created_time"),
        ("updated_at", "updated_time"),
        ("is_deleted", "is_delete"),
    ):
        if old in cols and new not in cols:
            conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN {old} TO {new}"))

    cols = _column_names_conn(conn, "vector_match_task")
    if "deleted_at" in cols:
        # Fold a leftover deleted_at timestamp into the is_delete flag, then
        # drop it.
        if "is_delete" not in cols:
            conn.execute(
                text(f"ALTER TABLE {ft_task} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
            )
        conn.execute(text(f"UPDATE {ft_task} SET is_delete = 1 WHERE deleted_at IS NOT NULL"))
        conn.execute(text(f"ALTER TABLE {ft_task} DROP COLUMN deleted_at"))


def _migrate_common_columns(conn, table: str) -> None:
    """Apply the created_time / is_delete / updated_time fix-ups to *table*."""
    from sqlalchemy import text

    ft = _table_sql_name(table)
    cols = _column_names_conn(conn, table)
    if "created_at" in cols and "created_time" not in cols:
        conn.execute(text(f"ALTER TABLE {ft} RENAME COLUMN created_at TO created_time"))
    cols = _column_names_conn(conn, table)
    if "is_delete" not in cols:
        conn.execute(
            text(f"ALTER TABLE {ft} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
        )
    cols = _column_names_conn(conn, table)
    if "updated_time" not in cols:
        conn.execute(text(f"ALTER TABLE {ft} ADD COLUMN updated_time TIMESTAMP NULL"))
    # Runs on every call, not only when updated_time was just added: any NULL
    # updated_time is backfilled from created_time.
    if "created_time" in cols:
        conn.execute(
            text(
                f"UPDATE {ft} SET updated_time = created_time "
                f"WHERE updated_time IS NULL"
            )
        )


def _ensure_time_is_delete_columns():
    """Normalize legacy timestamp and soft-delete columns.

    Phase 1, ``vector_match_task`` only, in one transaction:
      * rename ``created_at`` -> ``created_time``, ``updated_at`` ->
        ``updated_time`` and ``is_deleted`` -> ``is_delete``, each only when
        the target is absent;
      * if ``deleted_at`` exists, ensure ``is_delete`` exists, set it to 1
        where ``deleted_at`` is not NULL, then drop ``deleted_at``.

    Phase 2, every table in ``_SCHEMA_TABLES``, one transaction per table:
      * rename ``created_at`` -> ``created_time`` when the target is absent;
      * add ``is_delete INTEGER NOT NULL DEFAULT 0`` if missing;
      * add ``updated_time TIMESTAMP NULL`` if missing, and backfill NULL
        ``updated_time`` from ``created_time`` when that column exists.

    Column lists are always read through the migrating connection
    (``_column_names_conn``) so catalog lookups share its transaction.
    """
    with engine.begin() as conn:
        _migrate_task_legacy_columns(conn)
    for table in _SCHEMA_TABLES:
        with engine.begin() as conn:
            _migrate_common_columns(conn, table)
def init_db():
    """Create the schema and tables, then run the in-place column migrations.

    Steps:
      1. PostgreSQL only: ``CREATE SCHEMA IF NOT EXISTS`` for PG_SCHEMA.
         Other dialects (e.g. SQLite) do not support CREATE SCHEMA, and the
         rest of this module already treats them as schema-less.
      2. ``create_all`` for the metadata of ``models.Base``.
      3. ``_ensure_is_archived_column()`` and
         ``_ensure_time_is_delete_columns()`` for legacy databases.
    """
    from models import Base as ModelBase
    from sqlalchemy import text

    if engine.dialect.name == "postgresql":
        # Quote the identifier the same way _table_sql_name does, so the
        # schema created here is the one the migrations reference.
        quoted_schema = '"' + PG_SCHEMA.replace('"', '""') + '"'
        with engine.begin() as conn:
            conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}"))
    ModelBase.metadata.create_all(bind=engine)
    _ensure_is_archived_column()
    _ensure_time_is_delete_columns()