"""Database setup: engine, session factory and idempotent startup migrations.

Connection settings come from the process environment, optionally populated
from ``.env.local`` and ``.env`` located next to this file.
"""

import os
from pathlib import Path
from urllib.parse import quote_plus

from dotenv import load_dotenv
from sqlalchemy import create_engine, event, inspect, text
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import declarative_base, sessionmaker

_env_dir = Path(__file__).resolve().parent
# .env.local (local development) is loaded first so its values win over
# .env (fallback for Docker / production). With override=False neither file
# replaces variables that are already set in the process environment.
load_dotenv(_env_dir / ".env.local", override=False)
load_dotenv(_env_dir / ".env", override=False)

# Optional complete URL (highest priority), e.g.
# postgresql+psycopg2://user:pass@host:5432/dbname
_database_url = os.environ.get("DATABASE_URL", "").strip()
if _database_url:
    SQLALCHEMY_DATABASE_URL = _database_url
else:
    PG_HOST = os.environ.get("PG_HOST", "localhost")
    PG_PORT = os.environ.get("PG_PORT", "5432")
    PG_USER = os.environ.get("PG_USER", "postgres")
    PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")
    PG_DB = os.environ.get("PG_DB", "vector_match")
    # Percent-encode the password so characters such as "@", ":" or "/"
    # cannot corrupt the URL.
    _pw = quote_plus(PG_PASSWORD)
    SQLALCHEMY_DATABASE_URL = (
        f"postgresql+psycopg2://{PG_USER}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}"
    )

# Schema holding this application's tables (only used on PostgreSQL).
PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_pre_ping=True,
    pool_size=20,
    pool_recycle=180,
    pool_timeout=60,
    max_overflow=10,
)


def _is_postgresql() -> bool:
    """Return True when ``engine`` uses the PostgreSQL dialect."""
    return engine.dialect.name == "postgresql"


def _quoted_schema() -> str:
    """Return PG_SCHEMA as a double-quoted SQL identifier.

    Every raw SQL reference to the schema (CREATE SCHEMA, search_path,
    qualified table names) goes through this helper, so the schema is always
    used with exactly the case given in PG_SCHEMA. That is also the raw name
    passed to the inspector. Embedded double quotes are doubled.
    """
    return '"' + PG_SCHEMA.replace('"', '""') + '"'


@event.listens_for(engine, "connect")
def _set_search_path(dbapi_conn, connection_record):
    """Put PG_SCHEMA (then ``public``) on the search_path of each new connection.

    The SET is executed with DBAPI autocommit temporarily enabled. Otherwise
    the driver wraps it in an implicit transaction, and PostgreSQL reverts a
    SET when its transaction is rolled back. The pool's rollback-on-return
    would then leave pooled connections without the schema on their path.
    """
    if not _is_postgresql():
        # search_path is PostgreSQL-specific; other dialects reject it.
        return
    previous_autocommit = dbapi_conn.autocommit
    dbapi_conn.autocommit = True
    try:
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute(f"SET search_path TO {_quoted_schema()}, public")
        finally:
            cursor.close()
    finally:
        dbapi_conn.autocommit = previous_autocommit


SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Tables that receive the shared created_time / is_delete / updated_time
# columns in _ensure_time_is_delete_columns().
_SCHEMA_TABLES = (
    "vector_match_task",
    "vector_dataset",
    "vector_data_row",
    "vector_embedding",
    "match_result",
)


def get_db():
    """Yield a new session and always close it afterwards.

    If the consumer raises while holding the session, the pending
    transaction is rolled back before the exception propagates.
    """
    db = SessionLocal()
    try:
        yield db
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()


def _table_sql_name(table: str) -> str:
    """Return ``table`` as it should appear in raw DDL/DML.

    On PostgreSQL the name is quoted and qualified with PG_SCHEMA; other
    dialects get the bare table name.
    """
    if _is_postgresql():
        return f'{_quoted_schema()}."{table}"'
    return table


def _reflection_schema():
    """Return the ``schema`` argument for Inspector calls (PG_SCHEMA or None)."""
    return PG_SCHEMA if _is_postgresql() else None


def _column_names_conn(conn, table: str) -> set:
    """Return the set of column names of ``table``, reflected through ``conn``.

    Reflecting through the caller's connection keeps the catalog query inside
    the caller's transaction. Using ``inspect(engine)`` here would check out a
    second connection, and on PostgreSQL that can deadlock the process against
    itself: session A holds the ALTER lock and waits for B's metadata query,
    while B waits for A to release the lock.
    """
    insp = inspect(conn)
    return {c["name"] for c in insp.get_columns(table, schema=_reflection_schema())}


def _ensure_is_archived_column():
    """Add ``is_archived`` to ``vector_match_task`` on legacy databases.

    Does nothing when the table does not exist or already has the column.
    """
    try:
        cols = inspect(engine).get_columns(
            "vector_match_task", schema=_reflection_schema()
        )
    except NoSuchTableError:
        return
    if any(c["name"] == "is_archived" for c in cols):
        return
    ft = _table_sql_name("vector_match_task")
    with engine.begin() as conn:
        conn.execute(
            text(f"ALTER TABLE {ft} ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0")
        )


def _migrate_task_legacy_columns():
    """Convert legacy columns of ``vector_match_task`` in a single transaction.

    Renames created_at -> created_time, updated_at -> updated_time and
    is_deleted -> is_delete (each only if the target name is absent). A
    leftover ``deleted_at`` column is folded into ``is_delete``: rows where
    it is set get 1. The column is then dropped.
    """
    ft_task = _table_sql_name("vector_match_task")
    renames = (
        ("created_at", "created_time"),
        ("updated_at", "updated_time"),
        ("is_deleted", "is_delete"),
    )
    with engine.begin() as conn:
        cols = _column_names_conn(conn, "vector_match_task")
        for old, new in renames:
            if old in cols and new not in cols:
                conn.execute(
                    text(f"ALTER TABLE {ft_task} RENAME COLUMN {old} TO {new}")
                )

        # Re-read: the renames above may have produced is_delete.
        cols = _column_names_conn(conn, "vector_match_task")
        if "deleted_at" in cols:
            if "is_delete" not in cols:
                conn.execute(
                    text(
                        f"ALTER TABLE {ft_task} ADD COLUMN is_delete "
                        f"INTEGER NOT NULL DEFAULT 0"
                    )
                )
            conn.execute(
                text(f"UPDATE {ft_task} SET is_delete = 1 WHERE deleted_at IS NOT NULL")
            )
            conn.execute(text(f"ALTER TABLE {ft_task} DROP COLUMN deleted_at"))


def _migrate_common_columns(table: str):
    """Bring one table up to the shared created_time / is_delete / updated_time layout."""
    ft = _table_sql_name(table)
    with engine.begin() as conn:
        cols = _column_names_conn(conn, table)
        if "created_at" in cols and "created_time" not in cols:
            conn.execute(text(f"ALTER TABLE {ft} RENAME COLUMN created_at TO created_time"))

        cols = _column_names_conn(conn, table)
        if "is_delete" not in cols:
            conn.execute(
                text(f"ALTER TABLE {ft} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
            )

        cols = _column_names_conn(conn, table)
        if "updated_time" not in cols:
            conn.execute(text(f"ALTER TABLE {ft} ADD COLUMN updated_time TIMESTAMP NULL"))
            # NOTE(review): the original one-line layout lost its indentation;
            # this backfill is read as applying only when updated_time was just
            # added. Confirm against the original source.
            if "created_time" in cols:
                conn.execute(
                    text(
                        f"UPDATE {ft} SET updated_time = created_time "
                        f"WHERE updated_time IS NULL"
                    )
                )


def _ensure_time_is_delete_columns():
    """Normalize timestamp and soft-delete columns across all managed tables.

    First runs the task-table legacy conversions
    (``_migrate_task_legacy_columns``). Then, for every table in
    _SCHEMA_TABLES, it does three things:

    - renames created_at -> created_time;
    - adds ``is_delete`` if it is missing;
    - adds ``updated_time`` if it is missing.

    Each table is migrated in its own transaction.
    """
    _migrate_task_legacy_columns()
    for table in _SCHEMA_TABLES:
        _migrate_common_columns(table)


def init_db():
    """Create the schema (PostgreSQL only) and all model tables, then migrate columns."""
    # Local import: presumably deferred to avoid a circular import with
    # `models`. Importing it makes its tables known to ModelBase.metadata
    # before create_all. TODO confirm.
    from models import Base as ModelBase  # noqa: F401

    if _is_postgresql():
        with engine.begin() as conn:
            conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {_quoted_schema()}"))
    ModelBase.metadata.create_all(bind=engine)
    _ensure_is_archived_column()
    _ensure_time_is_delete_columns()