Spaces:
Running
Running
File size: 6,161 Bytes
010f0b1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | import os
from pathlib import Path
from urllib.parse import quote_plus
from dotenv import load_dotenv
from sqlalchemy import create_engine
# Directory that contains this module; both env files are looked up beside it.
_env_dir = Path(__file__).resolve().parent
# Load order matters: with override=False a key that is already set is kept,
# so .env.local (local development) wins and .env (Docker / production
# fallback) only fills in whatever is still missing.
for _env_file in (".env.local", ".env"):
    load_dotenv(_env_dir / _env_file, override=False)
del _env_file  # do not leave the loop variable behind as a module global
from sqlalchemy.orm import sessionmaker, declarative_base
# Optional full connection URL (highest priority), e.g.
# postgresql+psycopg2://user:pass@host:5432/dbname
_database_url = os.environ.get("DATABASE_URL", "").strip()
if _database_url:
    SQLALCHEMY_DATABASE_URL = _database_url
else:
    # No full URL given: assemble one from the individual PG_* variables.
    PG_HOST = os.environ.get("PG_HOST", "localhost")
    PG_PORT = os.environ.get("PG_PORT", "5432")
    PG_USER = os.environ.get("PG_USER", "postgres")
    PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")
    PG_DB = os.environ.get("PG_DB", "vector_match")
    # Percent-encode BOTH credentials: characters such as "@", ":" or "/" in
    # either one would otherwise be parsed as URL delimiters.
    _user = quote_plus(PG_USER)
    _pw = quote_plus(PG_PASSWORD)
    SQLALCHEMY_DATABASE_URL = (
        f"postgresql+psycopg2://{_user}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}"
    )
# Schema name used elsewhere in this module for search_path and for
# schema-qualified table names.
PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")
# Engine built from the URL resolved above.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    # Pool sizing: 20 pooled connections plus up to 10 overflow connections.
    pool_size=20,
    max_overflow=10,
    # Seconds to wait for a free connection before giving up.
    pool_timeout=60,
    # Recycle connections older than 180 seconds and test each connection
    # for liveness when it is checked out.
    pool_recycle=180,
    pool_pre_ping=True,
)
# Point every new DBAPI connection at the configured schema (PG_SCHEMA) first,
# with "public" as the fallback.
from sqlalchemy import event


@event.listens_for(engine, "connect")
def _set_search_path(dbapi_conn, connection_record):
    """Set ``search_path`` on a freshly opened DBAPI connection.

    Args:
        dbapi_conn: The raw DBAPI connection that was just created.
        connection_record: Pool record for the connection (unused).

    The statement is issued with DBAPI autocommit switched on so that it runs
    outside any transaction; otherwise a later rollback on the connection
    would also revert the ``search_path``. The previous autocommit setting is
    restored afterwards, and the cursor is closed even if the statement fails.
    """
    previous_autocommit = dbapi_conn.autocommit
    dbapi_conn.autocommit = True
    try:
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute(f"SET search_path TO {PG_SCHEMA}, public")
        finally:
            cursor.close()
    finally:
        dbapi_conn.autocommit = previous_autocommit
# Session factory bound to the module-level engine; sessions it creates do not
# autoflush.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Declarative base class created for ORM model definitions.
Base = declarative_base()
# Tables that the column migration further down in this module iterates over.
_SCHEMA_TABLES = (
    "vector_match_task",
    "vector_dataset",
    "vector_data_row",
    "vector_embedding",
    "match_result",
)

# "data" directory next to this module; created at import time if missing.
DATA_DIR = os.path.join(os.path.split(__file__)[0], "data")
os.makedirs(DATA_DIR, exist_ok=True)
def get_db():
    """Yield a database session and close it when the caller is done.

    Generator: one session from ``SessionLocal`` is yielded. If an
    ``Exception`` is raised at the ``yield`` point, the session is rolled
    back and the exception is re-raised. The session is closed in every case.
    """
    session = SessionLocal()
    try:
        try:
            yield session
        except Exception:
            # Undo any pending work before propagating the error.
            session.rollback()
            raise
    finally:
        session.close()
def _table_sql_name(table: str) -> str:
    """Return the name to use for ``table`` in raw SQL statements.

    On PostgreSQL the result is the double-quoted, schema-qualified name
    (``"<PG_SCHEMA>"."<table>"``); on any other dialect the bare table name
    is returned unchanged.
    """
    on_postgres = engine.dialect.name == "postgresql"
    return f'"{PG_SCHEMA}"."{table}"' if on_postgres else table
def _column_names_conn(conn, table: str) -> set:
    """Return the column names of ``table`` as seen through ``conn``.

    The lookup deliberately shares the caller's connection and transaction.
    Using ``inspect(engine)`` inside a transaction that holds an ALTER lock
    would open a second connection to read the catalog, which self-deadlocks
    on PostgreSQL (session A holds the lock while waiting for B's metadata
    query, and B waits for A to release the lock).
    """
    from sqlalchemy import inspect

    inspector = inspect(conn)
    # Only PostgreSQL is addressed by schema; other dialects pass None.
    if engine.dialect.name == "postgresql":
        target_schema = PG_SCHEMA
    else:
        target_schema = None
    column_info = inspector.get_columns(table, schema=target_schema)
    return {column["name"] for column in column_info}
def _ensure_is_archived_column():
    """Add the ``is_archived`` column to ``vector_match_task`` on older databases.

    Does nothing when the column already exists. If the table's columns
    cannot be read, the failure is logged and the migration is skipped
    (best effort) rather than aborting start-up.
    """
    import logging

    from sqlalchemy import inspect, text

    insp = inspect(engine)
    schema = PG_SCHEMA if engine.dialect.name == "postgresql" else None
    try:
        columns = insp.get_columns("vector_match_task", schema=schema)
    except Exception:
        # Stay best-effort, but leave a trace: a silently skipped migration
        # would otherwise only surface later as a missing-column error.
        logging.getLogger(__name__).warning(
            "Skipping is_archived migration: could not inspect vector_match_task",
            exc_info=True,
        )
        return
    if any(column["name"] == "is_archived" for column in columns):
        return
    # _table_sql_name() already returns the bare table name on dialects other
    # than PostgreSQL, so a single statement covers every dialect.
    qualified = _table_sql_name("vector_match_task")
    ddl = f"ALTER TABLE {qualified} ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0"
    with engine.begin() as conn:
        conn.execute(text(ddl))
def _ensure_time_is_delete_columns():
"""
Normalise legacy column names: created_at -> created_time; on the task table
updated_at -> updated_time; add is_delete to every table; is_deleted -> is_delete;
a leftover deleted_at column is migrated into is_delete and then dropped.
"""
# NOTE(review): leading indentation is missing in this copy of the file, so
# whether the later `cols = ...` refreshes and the final backfill UPDATE are
# nested inside the preceding `if` cannot be determined here; confirm against
# the original before restructuring this function.
from sqlalchemy import text
# Name of the task table as used in raw SQL (see _table_sql_name).
ft_task = _table_sql_name("vector_match_task")
# Part 1: renames and deleted_at migration specific to vector_match_task,
# run through a single engine.begin() block.
with engine.begin() as conn:
cols = _column_names_conn(conn, "vector_match_task")
# Each rename only happens when the old name exists and the new one does not.
if "created_at" in cols and "created_time" not in cols:
conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN created_at TO created_time"))
if "updated_at" in cols and "updated_time" not in cols:
conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN updated_at TO updated_time"))
if "is_deleted" in cols and "is_delete" not in cols:
conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN is_deleted TO is_delete"))
# Re-read the column names on the same connection before checking deleted_at.
cols = _column_names_conn(conn, "vector_match_task")
if "deleted_at" in cols:
if "is_delete" not in cols:
conn.execute(
text(f"ALTER TABLE {ft_task} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
)
# Rows with a non-NULL deleted_at get is_delete = 1, then the legacy
# deleted_at column is dropped.
conn.execute(text(f"UPDATE {ft_task} SET is_delete = 1 WHERE deleted_at IS NOT NULL"))
conn.execute(text(f"ALTER TABLE {ft_task} DROP COLUMN deleted_at"))
# Part 2: walk the tables named in _SCHEMA_TABLES and apply the
# created_time / is_delete / updated_time changes to each.
for table in _SCHEMA_TABLES:
ft = _table_sql_name(table)
with engine.begin() as conn:
cols = _column_names_conn(conn, table)
if "created_at" in cols and "created_time" not in cols:
conn.execute(text(f"ALTER TABLE {ft} RENAME COLUMN created_at TO created_time"))
cols = _column_names_conn(conn, table)
if "is_delete" not in cols:
conn.execute(
text(f"ALTER TABLE {ft} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
)
cols = _column_names_conn(conn, table)
if "updated_time" not in cols:
conn.execute(text(f"ALTER TABLE {ft} ADD COLUMN updated_time TIMESTAMP NULL"))
# Backfill: rows whose updated_time is NULL take their created_time value.
if "created_time" in cols:
conn.execute(
text(
f"UPDATE {ft} SET updated_time = created_time "
f"WHERE updated_time IS NULL"
)
)
def init_db():
    """Create the schema and tables, then run the column migrations.

    Steps, in order: ``CREATE SCHEMA IF NOT EXISTS`` for ``PG_SCHEMA``,
    ``create_all`` on the metadata of ``models.Base``, then
    ``_ensure_is_archived_column`` and ``_ensure_time_is_delete_columns``.
    """
    from models import Base as ModelBase  # noqa: F401
    from sqlalchemy import text

    create_schema = text(f"CREATE SCHEMA IF NOT EXISTS {PG_SCHEMA}")
    # engine.begin() commits when the block exits without an error.
    with engine.begin() as conn:
        conn.execute(create_schema)

    ModelBase.metadata.create_all(bind=engine)

    for migrate in (_ensure_is_archived_column, _ensure_time_is_delete_columns):
        migrate()
|