Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
| from urllib.parse import quote_plus | |
| from dotenv import load_dotenv | |
| from sqlalchemy import create_engine | |
# Resolve the .env files relative to this module rather than the process CWD.
_env_dir = Path(__file__).resolve().parent
# override=False never replaces a variable that is already set. Loading
# .env.local first therefore lets local-development values win, while .env
# acts as the fallback (Docker / production deployments).
for _env_file in (_env_dir / ".env.local", _env_dir / ".env"):
    load_dotenv(_env_file, override=False)
del _env_file
| from sqlalchemy.orm import sessionmaker, declarative_base | |
| # 可选:整串 URL(优先级最高),例如 postgresql+psycopg2://user:pass@host:5432/dbname | |
| _database_url = os.environ.get("DATABASE_URL", "").strip() | |
| if _database_url: | |
| SQLALCHEMY_DATABASE_URL = _database_url | |
| else: | |
| PG_HOST = os.environ.get("PG_HOST", "localhost") | |
| PG_PORT = os.environ.get("PG_PORT", "5432") | |
| PG_USER = os.environ.get("PG_USER", "postgres") | |
| PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres") | |
| PG_DB = os.environ.get("PG_DB", "vector_match") | |
| _pw = quote_plus(PG_PASSWORD) | |
| SQLALCHEMY_DATABASE_URL = ( | |
| f"postgresql+psycopg2://{PG_USER}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}" | |
| ) | |
# Schema that holds this service's tables (used on PostgreSQL only).
PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")

# Process-wide engine. Pool settings:
#   pool_pre_ping   - check each pooled connection on checkout, replacing dead ones
#   pool_recycle    - replace connections older than 180 seconds
#   pool_size       - keep up to 20 connections in the pool ...
#   max_overflow    - ... plus up to 10 extra under load
#   pool_timeout    - wait at most 60 seconds for a free connection
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    max_overflow=10,
    pool_pre_ping=True,
    pool_recycle=180,
    pool_size=20,
    pool_timeout=60,
)
# Point every new PostgreSQL connection at PG_SCHEMA (then public), so that
# unqualified table names resolve to this service's schema.
from sqlalchemy import event


def _set_search_path(dbapi_conn, connection_record):
    """Set ``search_path`` on a freshly opened DBAPI connection.

    Used as a pool ``connect`` listener (registered below for PostgreSQL).
    The ``SET`` runs with DBAPI autocommit temporarily enabled. Otherwise it
    would execute inside an implicit transaction, and the pool's rollback
    when the connection is returned would undo it.

    Args:
        dbapi_conn: The raw DBAPI connection that was just opened.
        connection_record: SQLAlchemy's pool record for it (unused).
    """
    previous_autocommit = dbapi_conn.autocommit
    dbapi_conn.autocommit = True
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute(f"SET search_path TO {PG_SCHEMA}, public")
    finally:
        cursor.close()
        dbapi_conn.autocommit = previous_autocommit


# The listener must be attached to the engine or it never runs. SET search_path
# is PostgreSQL-specific, so other dialects (e.g. a SQLite DATABASE_URL) skip it.
if engine.dialect.name == "postgresql":
    event.listen(engine, "connect", _set_search_path, insert=True)
# Session factory bound to the shared engine; autocommit and autoflush are
# both disabled, so pending changes are flushed/committed explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class for ORM models.
Base = declarative_base()
# "data" directory next to this module; created at import time if missing.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Tables processed one by one by the per-table column migration
# in _ensure_time_is_delete_columns.
_SCHEMA_TABLES = (
    "vector_match_task", "vector_dataset", "vector_data_row",
    "vector_embedding", "match_result",
)
def get_db():
    """Yield one ``SessionLocal()`` session and always close it afterwards.

    Generator-style dependency. If an exception is raised at the ``yield``
    (i.e. thrown back into the generator by the consumer), the session is
    rolled back and the exception is re-raised. The session is closed in
    every case.
    """
    session = SessionLocal()
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def _table_sql_name(table: str) -> str:
    """Return *table* as it should be spelled in raw SQL.

    On PostgreSQL the name is qualified with ``PG_SCHEMA``, with both parts
    double-quoted; on any other dialect the bare table name is returned.
    """
    if engine.dialect.name != "postgresql":
        return table
    return f'"{PG_SCHEMA}"."{table}"'
def _column_names_conn(conn, table: str) -> set:
    """Return the set of column names of *table*, read through *conn*.

    The lookup shares the caller's connection and transaction. This avoids
    opening a new connection with ``inspect(engine)`` to query the catalog
    while the current transaction holds an ALTER lock. On PostgreSQL that
    would deadlock the process against itself: session A holds the lock and
    waits for session B's metadata query, while B waits for A to release
    the lock.

    On PostgreSQL the table is looked up in ``PG_SCHEMA``; on other dialects
    no schema is passed.
    """
    from sqlalchemy import inspect

    target_schema = None
    if engine.dialect.name == "postgresql":
        target_schema = PG_SCHEMA
    columns = inspect(conn).get_columns(table, schema=target_schema)
    return set(col["name"] for col in columns)
def _ensure_is_archived_column():
    """Add ``vector_match_task.is_archived`` to databases created before it existed.

    The column is added as ``INTEGER NOT NULL DEFAULT 0``. Nothing happens
    when the column already exists or when the table does not exist.
    """
    from sqlalchemy import text
    from sqlalchemy.exc import NoSuchTableError

    table = "vector_match_task"
    with engine.begin() as conn:
        # Read the catalog on the same connection/transaction that runs the
        # ALTER, rather than on a second connection from inspect(engine).
        try:
            cols = _column_names_conn(conn, table)
        except NoSuchTableError:
            # No table yet: nothing to migrate. Other errors are not hidden.
            return
        if "is_archived" in cols:
            return
        # _table_sql_name yields the schema-qualified name on PostgreSQL and
        # the bare name elsewhere, so one statement covers every dialect.
        conn.execute(
            text(
                f"ALTER TABLE {_table_sql_name(table)} "
                "ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0"
            )
        )
def _ensure_time_is_delete_columns():
    """Normalize legacy column names and add missing bookkeeping columns.

    On ``vector_match_task`` (one transaction):
      * rename ``created_at`` -> ``created_time``, ``updated_at`` ->
        ``updated_time`` and ``is_deleted`` -> ``is_delete`` whenever the old
        name exists and the new one does not;
      * if a legacy ``deleted_at`` column exists: add ``is_delete`` if it is
        still missing, set ``is_delete = 1`` on rows whose ``deleted_at`` is
        not NULL, then drop ``deleted_at``.

    On every table in ``_SCHEMA_TABLES`` (one transaction per table):
      * rename ``created_at`` -> ``created_time`` when needed;
      * add ``is_delete INTEGER NOT NULL DEFAULT 0`` if missing;
      * if ``updated_time`` is missing, add it as ``TIMESTAMP NULL`` and, when
        ``created_time`` exists, backfill it from ``created_time``.

    Column lists are re-read on the same connection (``_column_names_conn``)
    between steps, so every check sees the effect of the previous change.
    """
    from sqlalchemy import text

    def run(conn, sql):
        conn.execute(text(sql))

    def rename_if_needed(conn, qualified, present, old, new):
        # Only rename when the legacy name is there and the target is not.
        if old in present and new not in present:
            run(conn, f"ALTER TABLE {qualified} RENAME COLUMN {old} TO {new}")

    task_table = "vector_match_task"
    task_sql = _table_sql_name(task_table)
    with engine.begin() as conn:
        snapshot = _column_names_conn(conn, task_table)
        for old, new in (
            ("created_at", "created_time"),
            ("updated_at", "updated_time"),
            ("is_deleted", "is_delete"),
        ):
            rename_if_needed(conn, task_sql, snapshot, old, new)

        snapshot = _column_names_conn(conn, task_table)
        if "deleted_at" in snapshot:
            # Carry the legacy soft-delete timestamp over into the flag, then drop it.
            if "is_delete" not in snapshot:
                run(conn, f"ALTER TABLE {task_sql} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
            run(conn, f"UPDATE {task_sql} SET is_delete = 1 WHERE deleted_at IS NOT NULL")
            run(conn, f"ALTER TABLE {task_sql} DROP COLUMN deleted_at")

    for name in _SCHEMA_TABLES:
        qualified = _table_sql_name(name)
        with engine.begin() as conn:
            rename_if_needed(
                conn, qualified, _column_names_conn(conn, name), "created_at", "created_time"
            )

            if "is_delete" not in _column_names_conn(conn, name):
                run(conn, f"ALTER TABLE {qualified} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")

            snapshot = _column_names_conn(conn, name)
            if "updated_time" in snapshot:
                continue
            run(conn, f"ALTER TABLE {qualified} ADD COLUMN updated_time TIMESTAMP NULL")
            if "created_time" in snapshot:
                run(
                    conn,
                    f"UPDATE {qualified} SET updated_time = created_time WHERE updated_time IS NULL",
                )
def init_db():
    """Create the schema (PostgreSQL only) and tables, then patch legacy columns.

    ``ModelBase.metadata.create_all`` creates tables that do not exist yet.
    The two ``_ensure_*`` helpers then add or rename columns on tables that
    already existed. Importing ``models`` is presumably what registers the
    ORM classes on that metadata (``models`` is not shown here, so confirm).
    """
    from models import Base as ModelBase  # noqa: F401
    from sqlalchemy import text

    # CREATE SCHEMA is PostgreSQL-specific. The rest of this module also
    # supports other dialects (e.g. SQLite via DATABASE_URL), where this
    # statement would fail, so only issue it on PostgreSQL.
    if engine.dialect.name == "postgresql":
        with engine.begin() as conn:
            conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {PG_SCHEMA}"))
    ModelBase.metadata.create_all(bind=engine)
    _ensure_is_archived_column()
    _ensure_time_is_delete_columns()