Spaces:
Running
Running
Upload 13 files
Browse files- hf-vector-match-api/.gitattributes +35 -0
- hf-vector-match-api/.gitignore +5 -0
- hf-vector-match-api/Dockerfile +16 -0
- hf-vector-match-api/README.md +10 -0
- hf-vector-match-api/database.py +172 -0
- hf-vector-match-api/main.py +599 -0
- hf-vector-match-api/models.py +128 -0
- hf-vector-match-api/requirements.txt +13 -0
- hf-vector-match-api/schemas.py +146 -0
- hf-vector-match-api/services/__init__.py +0 -0
- hf-vector-match-api/services/embedding_service.py +231 -0
- hf-vector-match-api/services/excel_service.py +73 -0
- hf-vector-match-api/services/match_service.py +260 -0
hf-vector-match-api/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
hf-vector-match-api/.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.env
|
| 4 |
+
.env.local
|
| 5 |
+
data/uploads/
|
hf-vector-match-api/Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
RUN useradd -m -u 1000 user
|
| 4 |
+
USER user
|
| 5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
COPY --chown=user requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 11 |
+
|
| 12 |
+
COPY --chown=user . /app
|
| 13 |
+
|
| 14 |
+
RUN mkdir -p data/uploads
|
| 15 |
+
|
| 16 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
hf-vector-match-api/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Vector Match Api
|
| 3 |
+
emoji: 😻
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
hf-vector-match-api/database.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from urllib.parse import quote_plus
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from sqlalchemy import create_engine
|
| 7 |
+
|
| 8 |
+
# Load env files that sit next to this module. With override=False a variable
# that is already set is never overwritten, so .env.local (local development)
# takes precedence over .env (fallback for Docker / production).
_env_dir = Path(__file__).resolve().parent
for _env_name in (".env.local", ".env"):
    load_dotenv(_env_dir / _env_name, override=False)
|
| 11 |
+
from sqlalchemy.orm import sessionmaker, declarative_base
|
| 12 |
+
|
| 13 |
+
# Optional: a complete connection URL (highest priority), for example
# postgresql+psycopg2://user:pass@host:5432/dbname
_database_url = os.environ.get("DATABASE_URL", "").strip()

if _database_url:
    SQLALCHEMY_DATABASE_URL = _database_url
else:
    # Otherwise assemble the URL from individual PG_* variables.
    PG_HOST = os.environ.get("PG_HOST", "localhost")
    PG_PORT = os.environ.get("PG_PORT", "5432")
    PG_USER = os.environ.get("PG_USER", "postgres")
    PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")
    PG_DB = os.environ.get("PG_DB", "vector_match")
    # URL-encode both credentials so characters such as "@", ":" or "/"
    # in either of them cannot corrupt the structure of the URL.
    _user = quote_plus(PG_USER)
    _pw = quote_plus(PG_PASSWORD)
    SQLALCHEMY_DATABASE_URL = (
        f"postgresql+psycopg2://{_user}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}"
    )

# Schema that holds the application's tables.
PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_pre_ping=True,   # test connections on checkout
    pool_size=20,
    pool_recycle=180,     # seconds before a pooled connection is replaced
    pool_timeout=60,      # seconds to wait for a free connection
    max_overflow=10,
)
|
| 39 |
+
|
| 40 |
+
# 每次连接自动切换到 vector_match schema
|
| 41 |
+
from sqlalchemy import event
|
| 42 |
+
|
| 43 |
+
@event.listens_for(engine, "connect")
def _set_search_path(dbapi_conn, connection_record):
    """
    Point every new DBAPI connection at the configured schema.

    The SET is issued with autocommit switched on. A session-level SET that
    runs inside a transaction is undone if that transaction is rolled back,
    which would silently reset search_path on the pooled connection; running
    it outside a transaction makes the setting stick for the connection's
    lifetime. The previous autocommit mode is restored afterwards.
    """
    previous_autocommit = dbapi_conn.autocommit
    dbapi_conn.autocommit = True
    try:
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute(f"SET search_path TO {PG_SCHEMA}, public")
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()
    finally:
        dbapi_conn.autocommit = previous_autocommit
|
| 48 |
+
|
| 49 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 50 |
+
Base = declarative_base()
|
| 51 |
+
|
| 52 |
+
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
| 53 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
_SCHEMA_TABLES = (
|
| 56 |
+
"vector_match_task",
|
| 57 |
+
"vector_dataset",
|
| 58 |
+
"vector_data_row",
|
| 59 |
+
"vector_embedding",
|
| 60 |
+
"match_result",
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_db():
    """
    Yield one database session and guarantee it is closed afterwards.

    If the consumer raises, the session is rolled back and the exception is
    re-raised; the session is closed in every case.
    """
    session = SessionLocal()
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _table_sql_name(table: str) -> str:
    """Return the SQL name of *table*: schema-qualified and quoted on PostgreSQL, bare otherwise."""
    on_postgres = engine.dialect.name == "postgresql"
    return f'"{PG_SCHEMA}"."{table}"' if on_postgres else table
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _column_names_conn(conn, table: str) -> set:
    """
    Return the set of column names of *table*, read through *conn*.

    The lookup shares the caller's connection and transaction. Calling
    inspect(engine) instead would open a second connection to read the
    catalog while this one holds ALTER locks, which self-deadlocks on
    PostgreSQL (session A holds the lock waiting for B's metadata query,
    B waits for A to release the lock).
    """
    from sqlalchemy import inspect

    schema = None
    if engine.dialect.name == "postgresql":
        schema = PG_SCHEMA
    columns = inspect(conn).get_columns(table, schema=schema)
    return {column["name"] for column in columns}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _ensure_is_archived_column():
    """Add the is_archived column to vector_match_task on databases created before it existed."""
    from sqlalchemy import inspect, text

    insp = inspect(engine)
    schema = PG_SCHEMA if engine.dialect.name == "postgresql" else None
    try:
        existing = insp.get_columns("vector_match_task", schema=schema)
    except Exception:
        # Best effort: if the table cannot be inspected, skip the migration.
        return
    if "is_archived" in {col["name"] for col in existing}:
        return

    # _table_sql_name() already yields the bare table name off PostgreSQL,
    # so a single statement covers every dialect.
    qualified = _table_sql_name("vector_match_task")
    statement = f"ALTER TABLE {qualified} ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0"
    with engine.begin() as conn:
        conn.execute(text(statement))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _ensure_time_is_delete_columns():
    """
    Bring legacy column names in line with the current schema.

    Task table: created_at -> created_time, updated_at -> updated_time,
    is_deleted -> is_delete; a leftover deleted_at is folded into is_delete
    and then dropped.
    Every table in _SCHEMA_TABLES: created_at -> created_time, and is_delete
    and updated_time are added when missing (a newly added updated_time is
    backfilled from created_time).
    """
    from sqlalchemy import text

    task_table = _table_sql_name("vector_match_task")
    task_renames = (
        ("created_at", "created_time"),
        ("updated_at", "updated_time"),
        ("is_deleted", "is_delete"),
    )

    with engine.begin() as conn:
        present = _column_names_conn(conn, "vector_match_task")
        # All three checks use the same column snapshot, taken before any rename.
        for old_name, new_name in task_renames:
            if old_name in present and new_name not in present:
                conn.execute(
                    text(f"ALTER TABLE {task_table} RENAME COLUMN {old_name} TO {new_name}")
                )
        # Re-read: the renames above may have changed the column set.
        present = _column_names_conn(conn, "vector_match_task")
        if "deleted_at" in present:
            if "is_delete" not in present:
                conn.execute(
                    text(f"ALTER TABLE {task_table} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
                )
            conn.execute(text(f"UPDATE {task_table} SET is_delete = 1 WHERE deleted_at IS NOT NULL"))
            conn.execute(text(f"ALTER TABLE {task_table} DROP COLUMN deleted_at"))

    # One transaction per table.
    for table in _SCHEMA_TABLES:
        qualified = _table_sql_name(table)
        with engine.begin() as conn:
            present = _column_names_conn(conn, table)
            if "created_at" in present and "created_time" not in present:
                conn.execute(text(f"ALTER TABLE {qualified} RENAME COLUMN created_at TO created_time"))

            present = _column_names_conn(conn, table)
            if "is_delete" not in present:
                conn.execute(
                    text(f"ALTER TABLE {qualified} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
                )

            present = _column_names_conn(conn, table)
            if "updated_time" in present:
                continue
            conn.execute(text(f"ALTER TABLE {qualified} ADD COLUMN updated_time TIMESTAMP NULL"))
            if "created_time" in present:
                conn.execute(
                    text(
                        f"UPDATE {qualified} SET updated_time = created_time "
                        f"WHERE updated_time IS NULL"
                    )
                )
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def init_db():
    """Create the schema and the model tables, then run the in-place column migrations."""
    from sqlalchemy import text
    from models import Base as ModelBase  # noqa: F401

    create_schema = text(f"CREATE SCHEMA IF NOT EXISTS {PG_SCHEMA}")
    with engine.connect() as conn:
        conn.execute(create_schema)
        conn.commit()

    ModelBase.metadata.create_all(bind=engine)
    # Migrations run after create_all so the tables they inspect exist.
    for migrate in (_ensure_is_archived_column, _ensure_time_is_delete_columns):
        migrate()
|
hf-vector-match-api/main.py
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import asyncio
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
import datetime
|
| 6 |
+
import httpx
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
from fastapi import FastAPI, UploadFile, File, Form, Depends, Query, HTTPException, BackgroundTasks
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from sqlalchemy.orm import Session
|
| 12 |
+
|
| 13 |
+
from database import get_db, init_db, SessionLocal
|
| 14 |
+
from models import (
|
| 15 |
+
VectorMatchTask, VectorDataset, VectorDataRow,
|
| 16 |
+
VectorEmbedding, MatchResult,
|
| 17 |
+
)
|
| 18 |
+
from schemas import (
|
| 19 |
+
TaskCreate, TaskDetail, TaskProgress, TaskListItem,
|
| 20 |
+
MatchResultItem, MatchResultPage, SourceWithCandidates, CandidateDetail,
|
| 21 |
+
UploadResponse, SettingItem, SettingsResponse, DatasetInfo,
|
| 22 |
+
)
|
| 23 |
+
from services.excel_service import save_upload_file, get_sheet_info, parse_excel_rows
|
| 24 |
+
from services.match_service import run_match_task
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
app = FastAPI(title="VectorMatch API", version="1.0.0")
|
| 28 |
+
|
| 29 |
+
app.add_middleware(
|
| 30 |
+
CORSMiddleware,
|
| 31 |
+
allow_origins=["*"],
|
| 32 |
+
allow_credentials=True,
|
| 33 |
+
allow_methods=["*"],
|
| 34 |
+
allow_headers=["*"],
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
import logging, traceback
|
| 38 |
+
from starlette.requests import Request
|
| 39 |
+
from starlette.responses import JSONResponse
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger("uvicorn.error")
|
| 42 |
+
|
| 43 |
+
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception with its traceback and answer with a JSON 500."""
    # Build the traceback from `exc` itself instead of traceback.format_exc(),
    # which depends on an exception being "currently handled" at call time.
    trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    logger.error("Unhandled error on %s %s:\n%s", request.method, request.url, trace)
    # NOTE(review): str(exc) is sent to the client and may expose internal
    # details (SQL, paths); kept because clients may rely on "detail".
    return JSONResponse(status_code=500, content={"detail": str(exc)})
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ─── 健康状态缓存 ─────────────────────────────────────────────────────────
|
| 50 |
+
_health_cache = {
|
| 51 |
+
"result": {"embedding_ok": False, "reranker_ok": False, "embedding_model": "",
|
| 52 |
+
"reranker_model": "", "reranker_enabled": False, "has_api_key": False},
|
| 53 |
+
"updated_at": 0,
|
| 54 |
+
}
|
| 55 |
+
_HEALTH_TTL = 30 # 缓存有效期(秒)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
async def _do_health_check():
    """
    Probe the embedding and rerank endpoints and refresh the health cache.

    Returns the result dict that is also stored in _health_cache["result"].
    Probe failures never raise; they leave the corresponding flag False and
    are logged so that a permanently failing probe is diagnosable.
    """
    import services.embedding_service as es

    api_key = es.SILICONFLOW_API_KEY
    result = {
        "embedding_ok": False,
        "reranker_ok": False,
        "embedding_model": es.EMBEDDING_MODEL,
        "reranker_model": es.RERANKER_MODEL,
        "reranker_enabled": es.RERANKER_ENABLED,
        "has_api_key": bool(api_key),
    }
    if api_key:
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        try:
            # NOTE(review): newer httpx releases dropped the `proxies` argument;
            # if the installed version rejects it, the error is logged below.
            async with httpx.AsyncClient(timeout=5.0, proxies={}) as client:
                try:
                    emb_resp = await client.post(
                        "https://api.siliconflow.cn/v1/embeddings",
                        headers=headers,
                        json={"model": es.EMBEDDING_MODEL, "input": ["ping"]},
                    )
                    result["embedding_ok"] = emb_resp.status_code == 200
                except Exception as exc:
                    logger.warning("Embedding health probe failed: %r", exc)
                if es.RERANKER_ENABLED:
                    try:
                        rerank_resp = await client.post(
                            "https://api.siliconflow.cn/v1/rerank",
                            headers=headers,
                            json={"model": es.RERANKER_MODEL, "query": "ping", "documents": ["pong"], "top_n": 1},
                        )
                        result["reranker_ok"] = rerank_resp.status_code == 200
                    except Exception as exc:
                        logger.warning("Reranker health probe failed: %r", exc)
        except Exception as exc:
            logger.warning("Health check could not run: %r", exc)
    _health_cache["result"] = result
    _health_cache["updated_at"] = time.time()
    return result
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
async def _health_polling_loop():
    """Run the health check forever, sleeping _HEALTH_TTL seconds between runs."""
    while True:
        try:
            await _do_health_check()
        except Exception:
            # Keep the loop alive, but leave a trace of what went wrong.
            logger.exception("Health check iteration failed")
        await asyncio.sleep(_HEALTH_TTL)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@app.on_event("startup")
async def startup():
    """Initialise the database and start the background health-check loop."""
    init_db()
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    app.state.health_task = asyncio.create_task(_health_polling_loop())
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ─── Upload Excel ───────────────────────────────────────────────────────────
|
| 117 |
+
@app.post("/api/upload", response_model=UploadResponse)
async def upload_excel(
    file: UploadFile = File(...),
    dataset_role: str = Form("source"),
    db: Session = Depends(get_db),
):
    """
    Store an uploaded Excel file and register it as a task-scoped dataset.

    Returns the new dataset id together with the sheet and column
    information reported by get_sheet_info().
    """
    filename = file.filename
    payload = await file.read()
    sheet_info = get_sheet_info(save_upload_file(payload, filename))

    record = VectorDataset(
        name=filename,
        file_name=filename,
        dataset_role=dataset_role,
        data_scope="task",
    )
    db.add(record)
    db.commit()
    db.refresh(record)  # populate the generated id

    return UploadResponse(
        dataset_id=record.id,
        file_name=filename,
        sheet_names=sheet_info["sheet_names"],
        columns=sheet_info["columns"],
        # Fall back to the plain column list when no "all_columns" key is present.
        all_columns=sheet_info.get("all_columns", sheet_info["columns"]),
    )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ─── Configure dataset (sheet, fields) ─────────────────────────────────────
|
| 147 |
+
@app.post("/api/dataset/{dataset_id}/configure")
def configure_dataset(
    dataset_id: int,
    sheet_name: str = Form(...),
    vector_fields: str = Form(...),
    db: Session = Depends(get_db),
):
    """
    Select the sheet and vector fields of a dataset and import its rows.

    vector_fields is a JSON document passed as a form string. It is parsed
    before anything is written, so malformed input yields a 400 instead of a
    500 with a half-applied configuration.

    Raises HTTPException 404 for an unknown dataset and 400 for invalid JSON.
    Returns {"status": "ok", "row_count": <number of imported rows>}.
    """
    dataset = db.query(VectorDataset).get(dataset_id)
    if not dataset:
        raise HTTPException(404, "Dataset not found")

    try:
        fields = json.loads(vector_fields)
    except json.JSONDecodeError as exc:
        raise HTTPException(400, f"vector_fields is not valid JSON: {exc}") from exc

    dataset.sheet_name = sheet_name
    dataset.vector_fields = vector_fields
    db.commit()

    # NOTE(review): file_name is the client-supplied upload name; confirm that
    # save_upload_file() sanitises it, otherwise this path can escape uploads/.
    filepath = os.path.join(
        os.path.dirname(__file__), "data", "uploads", dataset.file_name
    )
    rows = parse_excel_rows(filepath, sheet_name, fields)

    # NOTE(review): rows from an earlier configure call are not removed, so
    # configuring the same dataset twice appears to duplicate its rows.
    db.add_all([
        VectorDataRow(
            dataset_id=dataset.id,
            dataset_role=dataset.dataset_role,
            data_scope=dataset.data_scope,
            row_number=row_data["row_number"],
            raw_text=row_data["raw_text"],
            text_hash=row_data["text_hash"],
            field_values=row_data["field_values"],
        )
        for row_data in rows
    ])
    dataset.row_count = len(rows)
    db.commit()

    return {"status": "ok", "row_count": len(rows)}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ─── Get dataset info ──────────────────────────────────────────────────────
|
| 187 |
+
@app.get("/api/dataset/{dataset_id}", response_model=DatasetInfo)
def get_dataset(dataset_id: int, db: Session = Depends(get_db)):
    """Return the dataset with the given id, or respond 404 if there is none."""
    found = db.query(VectorDataset).get(dataset_id)
    if found:
        return found
    raise HTTPException(404, "Dataset not found")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ─── Create & start task ───────────────────────────────────────────────────
|
| 196 |
+
@app.post("/api/task", response_model=TaskDetail)
async def create_task(
    background_tasks: BackgroundTasks,
    source_dataset_id: int = Form(...),
    target_dataset_id: int = Form(...),
    match_mode: str = Form("two_file"),
    top_k: int = Form(10),
    rerank_top_k: int = Form(3),
    min_threshold: float = Form(0.70),
    candidate_scope: str = Form("current_task_target"),
    db: Session = Depends(get_db),
):
    """
    Create a match task for two datasets and schedule it to run in the background.

    The task is stored with status "pending", both datasets and their rows
    are linked to it, and the run is queued via BackgroundTasks.
    Raises HTTPException 400 when either dataset does not exist.
    """
    # Task code: wall-clock time at a fixed UTC+8 offset, millisecond precision.
    utc_plus_8 = datetime.timezone(datetime.timedelta(hours=8))
    now = datetime.datetime.now(utc_plus_8)
    task_code = f"{now:%Y%m%d%H%M%S}{now.microsecond // 1000:03d}"

    src = db.query(VectorDataset).get(source_dataset_id)
    tgt = db.query(VectorDataset).get(target_dataset_id)
    if not (src and tgt):
        raise HTTPException(400, "Source or target dataset not found")

    task = VectorMatchTask(
        task_code=task_code,
        match_mode=match_mode,
        candidate_scope=candidate_scope,
        source_dataset_id=source_dataset_id,
        target_dataset_id=target_dataset_id,
        top_k=top_k,
        rerank_top_k=rerank_top_k,
        min_threshold=min_threshold,
        status="pending",
    )
    db.add(task)
    db.commit()
    db.refresh(task)  # populate the generated task id

    # Attach both datasets, then all of their rows, to the new task.
    linked = (src, tgt)
    for dataset in linked:
        dataset.task_id = task.id
    for dataset in linked:
        db.query(VectorDataRow).filter(
            VectorDataRow.dataset_id == dataset.id
        ).update({"task_id": task.id})
    db.commit()

    background_tasks.add_task(_run_task_in_background, task.id)

    db.refresh(task)
    return task
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _run_task_in_background(task_id: int):
    """
    Run the async match task to completion from a synchronous worker.

    asyncio.run() creates a fresh event loop, closes it even when the task
    raises, and does not leave a closed loop registered on the calling thread.
    """
    asyncio.run(run_match_task(task_id, SessionLocal))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _get_alive_task(db: Session, task_id: int) -> Optional[VectorMatchTask]:
    """Return the task if it exists and is not soft-deleted (is_delete != 1), else None."""
    task = db.query(VectorMatchTask).get(task_id)
    if task and (task.is_delete or 0) != 1:
        return task
    return None
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# ─── Task progress ─────────────────────────────────────────────────────────
|
| 259 |
+
@app.get("/api/task/{task_id}/progress", response_model=TaskProgress)
def get_task_progress(task_id: int, db: Session = Depends(get_db)):
    """Return the progress view of a live task; 404 if it is missing or soft-deleted."""
    task = _get_alive_task(db, task_id)
    if task:
        return task
    raise HTTPException(404, "Task not found")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ─── Task detail ───────────────────────────────────────────────────────────
|
| 268 |
+
@app.get("/api/task/{task_id}", response_model=TaskDetail)
def get_task_detail(task_id: int, db: Session = Depends(get_db)):
    """Return the detail view of a live task; 404 if it is missing or soft-deleted."""
    task = _get_alive_task(db, task_id)
    if task:
        return task
    raise HTTPException(404, "Task not found")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ─── Task list ─────────────────────────────────────────────────────────────
|
| 277 |
+
@app.get("/api/tasks", response_model=List[TaskListItem])
def list_tasks(
    scope: str = Query("active", description="active=未归档, archived=仅归档, deleted=回收站"),
    db: Session = Depends(get_db),
):
    """
    List tasks, newest first, for one of three scopes.

    "deleted" returns soft-deleted tasks; "archived" and "active" return
    non-deleted tasks with is_archived equal to 1 and 0 respectively.
    Any other scope is rejected with HTTP 400.
    """
    if scope not in ("active", "archived", "deleted"):
        raise HTTPException(400, "scope 须为 active、archived 或 deleted")

    q = db.query(VectorMatchTask)
    if scope == "deleted":
        q = q.filter(VectorMatchTask.is_delete == 1)
    else:
        archived_flag = 1 if scope == "archived" else 0
        q = q.filter(VectorMatchTask.is_delete == 0).filter(
            VectorMatchTask.is_archived == archived_flag
        )

    return [
        TaskListItem(
            id=t.id,
            task_code=t.task_code,
            match_mode=t.match_mode,
            candidate_scope=t.candidate_scope,
            source_dataset_name=t.source_dataset.name if t.source_dataset else None,
            target_dataset_name=t.target_dataset.name if t.target_dataset else None,
            status=t.status,
            is_archived=t.is_archived or 0,
            is_delete=t.is_delete or 0,
            created_time=t.created_time,
        )
        for t in q.order_by(VectorMatchTask.created_time.desc()).all()
    ]
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@app.post("/api/task/{task_id}/archive")
def archive_task(task_id: int, db: Session = Depends(get_db)):
    """Mark a live task as archived (is_archived=1); 404 if it is missing or soft-deleted."""
    task = _get_alive_task(db, task_id)
    if task:
        task.is_archived = 1
        db.commit()
        return {"status": "ok"}
    raise HTTPException(404, "Task not found")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@app.post("/api/task/{task_id}/unarchive")
def unarchive_task(task_id: int, db: Session = Depends(get_db)):
    """Clear the archived flag of a live task (is_archived=0); 404 if it is missing or soft-deleted."""
    task = _get_alive_task(db, task_id)
    if task:
        task.is_archived = 0
        db.commit()
        return {"status": "ok"}
    raise HTTPException(404, "Task not found")
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@app.delete("/api/task/{task_id}")
def delete_task(task_id: int, db: Session = Depends(get_db)):
    """Soft-delete a task: set is_delete=1 while keeping its data in the database."""
    task = _get_alive_task(db, task_id)
    if task:
        task.is_delete = 1
        db.commit()
        return {"status": "ok"}
    raise HTTPException(404, "Task not found")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@app.post("/api/task/{task_id}/restore")
def restore_task(task_id: int, db: Session = Depends(get_db)):
    """Restore a soft-deleted task from the recycle bin (is_delete back to 0)."""
    task = db.query(VectorMatchTask).get(task_id)
    in_recycle_bin = bool(task) and (task.is_delete or 0) == 1
    if not in_recycle_bin:
        raise HTTPException(404, "Task not found or not deleted")
    task.is_delete = 0
    db.commit()
    return {"status": "ok"}
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ─── Match results ─────────────────────────────────────────────────────────
|
| 356 |
+
@app.get("/api/task/{task_id}/results", response_model=MatchResultPage)
def get_task_results(
    task_id: int,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    search: Optional[str] = None,
    level: Optional[str] = None,
    sort: str = "score_desc",
    db: Session = Depends(get_db),
):
    """
    Return one page of top-ranked (rank == 1) match results for a task.

    Parameters:
        search: optional case-insensitive substring that must occur in the
            source or the target text.
        level: optional match level; None or "all" disables the filter.
        sort: "score_desc", "score_asc", or anything else for source-row order.

    The search filter is applied in SQL before counting and paginating, so
    `total` and the page contents agree with the filter. Source and target
    rows are loaded in the same query through outer joins; a result whose
    row is missing is returned with empty text instead of raising.

    Raises HTTPException 404 when the task is missing or soft-deleted.
    """
    from sqlalchemy import func, or_
    from sqlalchemy.orm import aliased

    if not _get_alive_task(db, task_id):
        raise HTTPException(404, "Task not found")

    # The same table is joined twice, once per side of the match.
    src_alias = aliased(VectorDataRow)
    tgt_alias = aliased(VectorDataRow)
    query = (
        db.query(MatchResult, src_alias, tgt_alias)
        .select_from(MatchResult)
        .outerjoin(src_alias, src_alias.id == MatchResult.source_row_id)
        .outerjoin(tgt_alias, tgt_alias.id == MatchResult.target_row_id)
        .filter(MatchResult.task_id == task_id, MatchResult.rank == 1)
    )
    if level and level != "all":
        query = query.filter(MatchResult.match_level == level)
    if search:
        needle = search.lower()
        # autoescape keeps "%" and "_" in the search term literal.
        query = query.filter(
            or_(
                func.lower(src_alias.raw_text).contains(needle, autoescape=True),
                func.lower(tgt_alias.raw_text).contains(needle, autoescape=True),
            )
        )

    if sort == "score_desc":
        query = query.order_by(MatchResult.similarity_score.desc())
    elif sort == "score_asc":
        query = query.order_by(MatchResult.similarity_score.asc())
    else:
        query = query.order_by(MatchResult.source_row_id)

    total = query.count()
    page_rows = query.offset((page - 1) * page_size).limit(page_size).all()

    items = [
        MatchResultItem(
            id=r.id,
            source_row_id=r.source_row_id,
            source_row_number=src_row.row_number if src_row else 0,
            source_text=src_row.raw_text if src_row else "",
            target_text=tgt_row.raw_text if tgt_row else "",
            similarity_score=r.similarity_score,
            rerank_score=r.rerank_score,
            match_level=r.match_level or "",
            candidate_scope=r.candidate_scope,
            is_confirmed=r.is_confirmed,
        )
        for r, src_row, tgt_row in page_rows
    ]

    return MatchResultPage(items=items, total=total, page=page, page_size=page_size)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# ─── Candidate details for a source row ────────────────────────────────────
|
| 410 |
+
@app.get("/api/task/{task_id}/candidates/{source_row_id}", response_model=SourceWithCandidates)
def get_candidates(task_id: int, source_row_id: int, db: Session = Depends(get_db)):
    """Return one source row together with all of its match candidates.

    Candidates are the ``MatchResult`` rows of ``task_id`` whose
    ``source_row_id`` matches, ordered by ``rank``.

    Raises:
        HTTPException: 404 when ``_get_alive_task`` finds no task or the
            source row does not exist.
    """
    if not _get_alive_task(db, task_id):
        raise HTTPException(404, "Task not found")
    # Session.get replaces the legacy Query.get deprecated by SQLAlchemy 2.0.
    src_row = db.get(VectorDataRow, source_row_id)
    if not src_row:
        raise HTTPException(404, "Source row not found")

    results = (
        db.query(MatchResult)
        .filter(MatchResult.task_id == task_id, MatchResult.source_row_id == source_row_id)
        .order_by(MatchResult.rank)
        .all()
    )

    # Load every referenced target row with one query instead of one per result.
    target_ids = list({r.target_row_id for r in results})
    targets_by_id = {}
    if target_ids:
        targets_by_id = {
            row.id: row
            for row in db.query(VectorDataRow).filter(VectorDataRow.id.in_(target_ids)).all()
        }

    candidates = []
    for r in results:
        tgt_row = targets_by_id.get(r.target_row_id)
        candidates.append(CandidateDetail(
            rank=r.rank,
            rerank_rank=r.rerank_rank,
            target_row_id=r.target_row_id,
            # A dangling target reference degrades to empty text / id 0.
            target_text=tgt_row.raw_text if tgt_row else "",
            similarity_score=r.similarity_score,
            rerank_score=r.rerank_score,
            match_level=r.match_level or "",
            dataset_role="target",
            candidate_scope=r.candidate_scope,
            data_row_id=tgt_row.id if tgt_row else 0,
            is_confirmed=r.is_confirmed,
        ))

    return SourceWithCandidates(
        source_row_id=src_row.id,
        source_text=src_row.raw_text,
        source_row_number=src_row.row_number,
        dataset_role=src_row.dataset_role,
        data_row_id=src_row.id,
        candidates=candidates,
    )
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ─── Confirm match ─────────────────────────────────────────────────────────
|
| 453 |
+
@app.post("/api/result/{result_id}/confirm")
def confirm_match(result_id: int, db: Session = Depends(get_db)):
    """Mark a match result as manually confirmed (``is_confirmed = 1``).

    Raises:
        HTTPException: 404 when no ``MatchResult`` has the given id.
    """
    # Session.get replaces the legacy Query.get deprecated by SQLAlchemy 2.0.
    result = db.get(MatchResult, result_id)
    if result is None:
        raise HTTPException(404, "Result not found")
    result.is_confirmed = 1
    db.commit()
    return {"status": "ok"}
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
@app.post("/api/result/{result_id}/ignore")
def ignore_match(result_id: int, db: Session = Depends(get_db)):
    """Mark a match result as ignored (``is_confirmed = -1``).

    Raises:
        HTTPException: 404 when no ``MatchResult`` has the given id.
    """
    # Session.get replaces the legacy Query.get deprecated by SQLAlchemy 2.0.
    result = db.get(MatchResult, result_id)
    if result is None:
        raise HTTPException(404, "Result not found")
    result.is_confirmed = -1
    db.commit()
    return {"status": "ok"}
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# ─── Settings (read/write .env) ────────────────────────────────────────────
|
| 474 |
+
# Directory containing this module; the settings files live next to it.
_backend_dir = os.path.dirname(os.path.abspath(__file__))
_env_local = os.path.join(_backend_dir, ".env.local")
# Prefer ".env.local" when it exists at import time, otherwise use ".env".
ENV_PATH = _env_local if os.path.exists(_env_local) else os.path.join(_backend_dir, ".env")
|
| 477 |
+
|
| 478 |
+
def _read_env() -> dict:
    """Parse the file at ``ENV_PATH`` into a ``{key: value}`` dict.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. Keys and values are whitespace-stripped; only the first ``=``
    separates key from value. A missing file yields an empty dict.
    """
    if not os.path.exists(ENV_PATH):
        return {}

    parsed = {}
    with open(ENV_PATH, "r", encoding="utf-8") as env_file:
        for raw_line in env_file:
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            name, _, value = entry.partition("=")
            parsed[name.strip()] = value.strip()
    return parsed
|
| 488 |
+
|
| 489 |
+
def _write_env(settings: dict):
    """Overwrite the file at ``ENV_PATH`` with one ``key=value`` line per entry."""
    with open(ENV_PATH, "w", encoding="utf-8") as env_file:
        env_file.writelines(f"{name}={value}\n" for name, value in settings.items())
|
| 493 |
+
|
| 494 |
+
@app.get("/api/settings", response_model=SettingsResponse)
def get_settings():
    """Return the key/value pairs currently stored in the env file."""
    stored = _read_env()
    return SettingsResponse(settings=stored)
|
| 497 |
+
|
| 498 |
+
@app.post("/api/settings")
async def update_settings(items: List[SettingItem]):
    """Persist the submitted settings and apply them without a restart.

    The items are merged into the env file, the file is reloaded into
    ``os.environ``, the configuration globals of
    ``services.embedding_service`` are refreshed and the health cache is
    recomputed.

    Raises:
        HTTPException: 400 when a key is empty or contains ``=`` or a line
            break, when a value contains a line break, or when
            ``EMBEDDING_DIM`` is not a positive integer.
    """
    # Validate everything BEFORE touching the file. A bad EMBEDDING_DIM would
    # otherwise be persisted and then fail in int() below, and a line break
    # or "=" would inject extra entries into the env file.
    for item in items:
        key = item.key.strip()
        if not key or any(ch in item.key for ch in "=\r\n"):
            raise HTTPException(400, f"Invalid setting key: {item.key!r}")
        if "\n" in item.value or "\r" in item.value:
            raise HTTPException(400, f"Setting {key!r} must not contain line breaks")
        if key == "EMBEDDING_DIM":
            dim_text = item.value.strip().strip("'\"")
            try:
                dim_ok = int(dim_text) > 0
            except ValueError:
                dim_ok = False
            if not dim_ok:
                raise HTTPException(400, "EMBEDDING_DIM must be a positive integer")

    current = _read_env()
    for item in items:
        current[item.key] = item.value
    _write_env(current)
    # Reload the environment after saving so no manual restart is needed.
    from dotenv import load_dotenv
    load_dotenv(ENV_PATH, override=True)
    # Keep the configuration constants of embedding_service in sync.
    import services.embedding_service as es
    es.SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
    es.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-m3")
    es.EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "1024"))
    es.RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "Qwen/Qwen3-VL-Reranker-8B")
    es.RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() == "true"
    # Refresh the health cache now so the next frontend request sees the new state.
    await _do_health_check()
    return {"status": "ok", "message": "已保存,配置已实时生效"}
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
# ─── 健康检查(返回后端缓存,秒级响应)────────────────────────────────────────
|
| 520 |
+
@app.get("/api/health")
async def health_check(force: bool = False):
    """Return the cached health status, refreshing it first when needed.

    The cache is recomputed when ``force`` is true or when the cached entry
    is older than ``_HEALTH_TTL`` seconds.
    """
    needs_refresh = force or (time.time() - _health_cache["updated_at"]) > _HEALTH_TTL
    if needs_refresh:
        await _do_health_check()
    return _health_cache["result"]
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# ─── Export results ────────────────────────────────────────────────────────
|
| 529 |
+
@app.get("/api/task/{task_id}/export")
def export_results(task_id: int, db: Session = Depends(get_db)):
    """Export all match results of a task as a styled ``.xlsx`` download.

    One spreadsheet row is written per ``MatchResult``, ordered by source row
    id and rank.

    Raises:
        HTTPException: 404 when ``_get_alive_task`` finds no task.
    """
    import io
    import openpyxl
    from openpyxl.styles import Font, Alignment, PatternFill
    from openpyxl.utils import get_column_letter
    from fastapi.responses import StreamingResponse

    task = _get_alive_task(db, task_id)
    if not task:
        raise HTTPException(404, "Task not found")

    results = (
        db.query(MatchResult)
        .filter(MatchResult.task_id == task_id)
        .order_by(MatchResult.source_row_id, MatchResult.rank)
        .all()
    )

    # Fetch every referenced data row up front instead of two queries per
    # result. Ids are loaded in chunks to keep each IN list small.
    row_ids = sorted(
        {r.source_row_id for r in results} | {r.target_row_id for r in results}
    )
    rows_by_id = {}
    chunk_size = 500
    for start in range(0, len(row_ids), chunk_size):
        chunk = row_ids[start:start + chunk_size]
        for row in db.query(VectorDataRow).filter(VectorDataRow.id.in_(chunk)).all():
            rows_by_id[row.id] = row

    wb = openpyxl.Workbook()
    ws = wb.active
    ws.title = "匹配结果"

    headers = ["源行号", "源数据内容", "候选排名", "目标候选内容", "相似度(%)", "精排分", "匹配等级", "候选来源"]
    ws.append(headers)

    # Header styling
    header_font = Font(bold=True, color="FFFFFF")
    header_fill = PatternFill(start_color="1F4E79", end_color="1F4E79", fill_type="solid")
    header_alignment = Alignment(horizontal="center", vertical="center")
    for cell in ws[1]:
        cell.font = header_font
        cell.fill = header_fill
        cell.alignment = header_alignment

    # Internal codes -> labels shown in the sheet; unknown codes pass through.
    level_map = {"high": "高度匹配", "possible": "可能匹配", "low_confidence": "低置信", "no_match": "不匹配"}
    scope_map = {"current_task_target": "目标候选集", "history": "历史数据", "standard": "标准库"}

    for r in results:
        src = rows_by_id.get(r.source_row_id)
        tgt = rows_by_id.get(r.target_row_id)
        ws.append([
            src.row_number if src else "",
            src.raw_text if src else "",
            r.rank,
            tgt.raw_text if tgt else "",
            round(r.similarity_score * 100, 2),
            round(r.rerank_score, 4) if r.rerank_score is not None else "",
            level_map.get(r.match_level, r.match_level),
            scope_map.get(r.candidate_scope, r.candidate_scope or ""),
        ])

    # Column widths
    col_widths = [8, 40, 10, 40, 12, 12, 12, 14]
    for i, w in enumerate(col_widths, 1):
        ws.column_dimensions[get_column_letter(i)].width = w

    output = io.BytesIO()
    wb.save(output)
    output.seek(0)

    return StreamingResponse(
        output,
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f"attachment; filename=match_result_{task.task_code}.xlsx"},
    )
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# Script entry point: serve "main:app" with uvicorn on all interfaces,
# port 8000, with auto-reload enabled.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
hf-vector-match-api/models.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
from sqlalchemy import Column, Integer, String, Float, DateTime, Text, LargeBinary, ForeignKey
|
| 3 |
+
from sqlalchemy.orm import relationship
|
| 4 |
+
from database import Base
|
| 5 |
+
|
| 6 |
+
_TZ_BEIJING = datetime.timezone(datetime.timedelta(hours=8))


def _now_beijing():
    """Return the current UTC+8 wall-clock time as a naive ``datetime``."""
    aware_now = datetime.datetime.now(tz=_TZ_BEIJING)
    # Drop tzinfo so the value is stored as a plain local timestamp.
    return aware_now.replace(tzinfo=None)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VectorMatchTask(Base):
    """ORM model for ``vector_match_task``: one row per matching task.

    Holds the task configuration (mode, candidate scope, top-k, threshold),
    its status, row and vector counters, one 0-100 progress column per
    pipeline stage, and the archive / soft-delete flags.
    """

    __tablename__ = "vector_match_task"
    __table_args__ = {"comment": "向量匹配任务表"}

    id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
    task_code = Column(String(30), unique=True, nullable=False, index=True, comment="任务编号,格式:YYYYMMDDHHMMSSmmm")
    match_mode = Column(String(50), nullable=False, default="two_file", comment="匹配模式:two_file/history/standard")
    candidate_scope = Column(String(50), nullable=False, default="current_task_target", comment="候选范围:current_task_target/history/standard")
    source_dataset_id = Column(Integer, ForeignKey("vector_dataset.id"), nullable=True, comment="源数据集ID")
    target_dataset_id = Column(Integer, ForeignKey("vector_dataset.id"), nullable=True, comment="目标候选集ID")
    top_k = Column(Integer, default=10, comment="每条源数据保留的Top-K候选数")
    rerank_top_k = Column(Integer, default=3, comment="Reranker重排序后保留的Top-K数")
    min_threshold = Column(Float, default=0.70, comment="最低相似度阈值")
    status = Column(String(20), default="pending", comment="任务状态:pending/running/completed/failed")
    source_row_count = Column(Integer, default=0, comment="源数据行数")
    target_row_count = Column(Integer, default=0, comment="目标候选行数")
    high_match_count = Column(Integer, default=0, comment="高度匹配数量(score>=0.90)")
    low_confidence_count = Column(Integer, default=0, comment="低置信数量(score<0.70)")
    reused_vectors = Column(Integer, default=0, comment="通过text_hash复用的向量数")
    new_vectors = Column(Integer, default=0, comment="新生成的向量数")
    progress_parse_source = Column(Integer, default=0, comment="解析源数据集进度(0-100)")
    progress_parse_target = Column(Integer, default=0, comment="解析目标候选集进度(0-100)")
    progress_vectorize = Column(Integer, default=0, comment="向量化进度(0-100)")
    progress_load_candidates = Column(Integer, default=0, comment="加载候选范围进度(0-100)")
    progress_similarity = Column(Integer, default=0, comment="相似度计算进度(0-100)")
    progress_rerank = Column(Integer, default=0, comment="Reranker重排序进度(0-100)")
    progress_save_results = Column(Integer, default=0, comment="保存结果进度(0-100)")
    created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
    updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")
    is_archived = Column(Integer, default=0, comment="是否归档:0=未归档,1=已归档")
    is_delete = Column(Integer, default=0, comment="是否删除:0=未删除,1=已删除")

    # Two foreign keys point at vector_dataset, so each relationship names its own.
    source_dataset = relationship("VectorDataset", foreign_keys=[source_dataset_id])
    target_dataset = relationship("VectorDataset", foreign_keys=[target_dataset_id])
    results = relationship("MatchResult", back_populates="task")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class VectorDataset(Base):
    """ORM model for ``vector_dataset``: an uploaded or logical dataset.

    Records the dataset's role (source / target), data scope, originating
    file and sheet, the fields used for vectorization (stored as JSON text)
    and its row count.
    """

    __tablename__ = "vector_dataset"
    __table_args__ = {"comment": "向量数据集表(上传或逻辑数据集)"}

    id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
    task_id = Column(Integer, ForeignKey("vector_match_task.id"), nullable=True, comment="所属任务ID")
    name = Column(String(255), nullable=False, comment="数据集名称")
    file_name = Column(String(255), nullable=True, comment="上传文件名")
    sheet_name = Column(String(100), nullable=True, comment="Excel工作表名")
    dataset_role = Column(String(20), nullable=False, comment="数据集角色:source(源)/target(目标候选)")
    data_scope = Column(String(20), default="task", comment="数据范围:task/history/standard")
    vector_fields = Column(Text, nullable=True, comment="参与向量化的字段列表(JSON)")
    row_count = Column(Integer, default=0, comment="数据行数")
    is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
    created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
    updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")

    rows = relationship("VectorDataRow", back_populates="dataset")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class VectorDataRow(Base):
    """ORM model for ``vector_data_row``: a single data row of a dataset.

    Stores the row's concatenated raw text, its SHA-256 ``text_hash`` (used
    for vector reuse per the column comment), the per-field values as JSON
    text, and a one-to-one link to its embedding.
    """

    __tablename__ = "vector_data_row"
    __table_args__ = {"comment": "向量数据行表(单行物料/申报项等)"}

    id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
    dataset_id = Column(Integer, ForeignKey("vector_dataset.id"), nullable=False, index=True, comment="所属数据集ID")
    task_id = Column(Integer, nullable=True, index=True, comment="所属任务ID")
    dataset_role = Column(String(20), nullable=False, index=True, comment="数据集角色:source/target")
    data_scope = Column(String(20), default="task", index=True, comment="数据范围:task/history/standard")
    row_number = Column(Integer, nullable=False, comment="Excel中的行号")
    raw_text = Column(Text, nullable=False, comment="拼接后的原始文本")
    text_hash = Column(String(64), nullable=True, index=True, comment="文本SHA256哈希,用于向量复用")
    field_values = Column(Text, nullable=True, comment="各字段值(JSON)")
    is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
    created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
    updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")

    dataset = relationship("VectorDataset", back_populates="rows")
    # uselist=False makes this a scalar (one-to-one) relationship.
    embedding = relationship("VectorEmbedding", back_populates="data_row", uselist=False)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class VectorEmbedding(Base):
    """ORM model for ``vector_embedding``: the stored vector of one data row.

    ``data_row_id`` is unique, so each data row has at most one embedding.
    The vector itself is kept as binary data alongside the model name and
    dimension.
    """

    __tablename__ = "vector_embedding"
    __table_args__ = {"comment": "向量嵌入表(与数据行一对一)"}

    id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
    data_row_id = Column(Integer, ForeignKey("vector_data_row.id"), unique=True, nullable=False, index=True, comment="关联 vector_data_row.id")
    text_hash = Column(String(64), nullable=False, index=True, comment="与行一致的文本哈希")
    embedding = Column(LargeBinary(length=65536), nullable=False, comment="float32 数组二进制存储")
    model_name = Column(String(100), nullable=True, comment="生成向量所用模型名")
    dimension = Column(Integer, nullable=True, comment="向量维度")
    is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
    created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
    updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")

    data_row = relationship("VectorDataRow", back_populates="embedding")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class MatchResult(Base):
    """ORM model for ``match_result``: one (source row, candidate row) pairing.

    Stores the similarity and optional rerank scores with their ranks, the
    match level, the candidate scope and the manual review state
    (``is_confirmed``: 0 = unreviewed, 1 = confirmed, -1 = ignored).
    """

    __tablename__ = "match_result"
    __table_args__ = {"comment": "匹配结果表(源行与候选的关联及得分)"}

    id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
    task_id = Column(Integer, ForeignKey("vector_match_task.id"), nullable=False, index=True, comment="所属任务ID")
    source_row_id = Column(Integer, ForeignKey("vector_data_row.id"), nullable=False, comment="源数据行ID")
    target_row_id = Column(Integer, ForeignKey("vector_data_row.id"), nullable=False, comment="目标候选行ID")
    similarity_score = Column(Float, nullable=False, comment="余弦相似度分数(0-1)")
    rerank_score = Column(Float, nullable=True, comment="Reranker精排分数,越高越相关")
    rank = Column(Integer, nullable=False, comment="排名(1=最相似)")
    rerank_rank = Column(Integer, nullable=True, comment="Reranker重排后的排名")
    candidate_scope = Column(String(50), nullable=True, comment="候选来源范围")
    match_level = Column(String(20), nullable=True, comment="匹配等级:high/possible/low_confidence/no_match")
    is_confirmed = Column(Integer, default=0, comment="是否已人工确认:0=未确认,1=已确认,-1=已忽略")
    is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
    created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
    updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")

    task = relationship("VectorMatchTask", back_populates="results")
    # Both columns reference vector_data_row, so each relationship names its key.
    source_row = relationship("VectorDataRow", foreign_keys=[source_row_id])
    target_row = relationship("VectorDataRow", foreign_keys=[target_row_id])
|
hf-vector-match-api/requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.0
|
| 2 |
+
uvicorn==0.30.6
|
| 3 |
+
sqlalchemy==2.0.35
|
| 4 |
+
python-multipart==0.0.12
|
| 5 |
+
openpyxl==3.1.5
|
| 6 |
+
numpy==1.24.4
|
| 7 |
+
pandas==2.0.3
|
| 8 |
+
httpx==0.26.0
|
| 9 |
+
pydantic==2.9.2
|
| 10 |
+
psycopg2-binary==2.9.9
|
| 11 |
+
pymysql==1.1.1
|
| 12 |
+
cryptography==43.0.1
|
| 13 |
+
python-dotenv==1.0.1
|
hf-vector-match-api/schemas.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional, List
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DatasetInfo(BaseModel):
    """Summary of a dataset, buildable directly from an ORM object."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = {"from_attributes": True}

    id: int
    name: str
    file_name: Optional[str] = None
    sheet_name: Optional[str] = None
    dataset_role: str
    data_scope: str
    vector_fields: Optional[str] = None
    row_count: int = 0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TaskCreate(BaseModel):
    """Parameters for creating a matching task; every field has a default."""

    match_mode: str = "two_file"
    top_k: int = 3
    min_threshold: float = 0.70
    candidate_scope: str = "current_task_target"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TaskProgress(BaseModel):
    """Progress snapshot of a task: status, counters and per-stage progress."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = {"from_attributes": True}

    id: int
    task_code: str
    status: str
    source_row_count: int
    target_row_count: int
    reused_vectors: int
    new_vectors: int
    progress_parse_source: int
    progress_parse_target: int
    progress_vectorize: int
    progress_load_candidates: int
    progress_similarity: int
    progress_rerank: int = 0
    progress_save_results: int
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class TaskDetail(BaseModel):
    """Full view of a task: configuration, counters, datasets and timestamps."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = {"from_attributes": True}

    id: int
    task_code: str
    match_mode: str
    candidate_scope: str
    top_k: int
    min_threshold: float
    status: str
    source_row_count: int
    target_row_count: int
    high_match_count: int
    low_confidence_count: int
    reused_vectors: int
    new_vectors: int
    source_dataset: Optional[DatasetInfo] = None
    target_dataset: Optional[DatasetInfo] = None
    created_time: Optional[datetime] = None
    updated_time: Optional[datetime] = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TaskListItem(BaseModel):
    """Compact task entry for list views, with dataset names instead of objects."""

    id: int
    task_code: str
    match_mode: str
    candidate_scope: str
    source_dataset_name: Optional[str] = None
    target_dataset_name: Optional[str] = None
    status: str
    is_archived: int = 0
    is_delete: int = 0
    created_time: Optional[datetime] = None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MatchResultItem(BaseModel):
    """One entry of a match-result page: scores plus source and target text."""

    id: int
    source_row_id: int
    source_row_number: int
    source_text: str
    target_text: str
    similarity_score: float
    rerank_score: Optional[float] = None
    match_level: str
    candidate_scope: Optional[str] = None
    is_confirmed: int = 0
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class MatchResultPage(BaseModel):
    """Paginated list of match results with the paging parameters echoed back."""

    items: List[MatchResultItem]
    total: int
    page: int
    page_size: int
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class CandidateDetail(BaseModel):
    """A single ranked candidate for a source row, with its scores and text."""

    rank: int
    rerank_rank: Optional[int] = None
    target_row_id: int
    target_text: str
    similarity_score: float
    rerank_score: Optional[float] = None
    match_level: str
    dataset_role: str
    candidate_scope: Optional[str] = None
    data_row_id: int
    is_confirmed: int = 0
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class SourceWithCandidates(BaseModel):
    """A source row together with the list of its match candidates."""

    source_row_id: int
    source_text: str
    source_row_number: int
    dataset_role: str
    data_row_id: int
    candidates: List[CandidateDetail]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class SheetInfo(BaseModel):
    """Sheet names of a workbook plus a ``columns`` dict."""

    sheet_names: List[str]
    # NOTE(review): presumably maps sheet name -> column names; confirm against the producer.
    columns: dict
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class UploadResponse(BaseModel):
    """Response of a file upload: the new dataset id and workbook metadata."""

    dataset_id: int
    file_name: str
    sheet_names: List[str]
    columns: dict
    all_columns: dict = {}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class SettingItem(BaseModel):
    """A single key/value setting."""

    key: str
    value: str
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class SettingsResponse(BaseModel):
    """Wrapper carrying all settings in a single ``settings`` dict."""

    settings: dict
|
hf-vector-match-api/services/__init__.py
ADDED
|
File without changes
|
hf-vector-match-api/services/embedding_service.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import hashlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Optional, Dict
|
| 6 |
+
import httpx
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
# Load ".env" from the directory one level above this module's directory.
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))

# Embedding configuration, read from the environment once at import time.
EMBEDDING_API_URL = os.environ.get("EMBEDDING_API_URL", "https://api.siliconflow.cn/v1/embeddings")
# Alternative model/dimension defaults, currently disabled:
# EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "Qwen/Qwen3-VL-Embedding-8B")
# EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "4096"))
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-m3")
EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "1024"))
EMBEDDING_PROVIDER = os.environ.get("EMBEDDING_PROVIDER", "siliconflow")
SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")

# Reranker configuration; RERANKER_ENABLED is true only for the string "true" (case-insensitive).
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "Qwen/Qwen3-VL-Reranker-8B")
RERANKER_API_URL = os.environ.get("RERANKER_API_URL", "https://api.siliconflow.cn/v1/rerank")
RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() == "true"

# Bypass system proxy for SiliconFlow API calls: make sure the API host is
# listed in NO_PROXY, appending to any existing value.
if "NO_PROXY" not in os.environ:
    os.environ["NO_PROXY"] = "api.siliconflow.cn"
elif "siliconflow" not in os.environ.get("NO_PROXY", ""):
    os.environ["NO_PROXY"] = os.environ["NO_PROXY"] + ",api.siliconflow.cn"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _build_simple_embedding(text: str, dim: int = 768) -> np.ndarray:
    """Build a deterministic pseudo-embedding for ``text``.

    Testing fallback for when no real embedding API is available: the first
    four bytes of the text's SHA-512 digest seed a random generator, and the
    resulting float32 vector of length ``dim`` is L2-normalised (left
    unchanged when its norm is zero).
    """
    digest = hashlib.sha512(text.encode("utf-8")).digest()
    generator = np.random.RandomState(int.from_bytes(digest[:4], "big"))
    vector = generator.randn(dim).astype(np.float32)
    length = np.linalg.norm(vector)
    return vector / length if length > 0 else vector
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
async def get_embeddings_batch(texts: List[str], model: Optional[str] = None) -> List[np.ndarray]:
    """Embed ``texts`` with the backend selected by ``EMBEDDING_PROVIDER``.

    ``model`` defaults to ``EMBEDDING_MODEL``. An unrecognised provider falls
    back to deterministic pseudo-embeddings of size ``EMBEDDING_DIM``.
    """
    chosen_model = model or EMBEDDING_MODEL
    backends = {
        "siliconflow": _siliconflow_embeddings,
        "ollama": _ollama_embeddings,
        "openai": _openai_embeddings,
    }
    backend = backends.get(EMBEDDING_PROVIDER.lower())
    if backend is None:
        return [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
    return await backend(texts, chosen_model)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
async def _siliconflow_embeddings(texts: List[str], model: str) -> List[np.ndarray]:
    """Embed ``texts`` through the SiliconFlow embedding API.

    API docs: https://docs.siliconflow.cn/api-reference/embeddings
    Texts are sent in batches of 64. A batch answered with a non-200 status
    is replaced by pseudo-embeddings; any exception replaces the whole
    result with pseudo-embeddings, as does a missing API key.
    """
    token = SILICONFLOW_API_KEY
    if not token:
        print("[WARN] SILICONFLOW_API_KEY not set, falling back to pseudo embeddings")
        return [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]

    endpoint = EMBEDDING_API_URL or "https://api.siliconflow.cn/v1/embeddings"
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    batch_size = 64  # upper bound of inputs sent per request
    vectors = []
    try:
        async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
            for start in range(0, len(texts), batch_size):
                chunk = texts[start : start + batch_size]
                resp = await client.post(
                    endpoint,
                    headers=request_headers,
                    json={"model": model, "input": chunk, "encoding_format": "float"},
                )
                if resp.status_code != 200:
                    print(f"[ERROR] SiliconFlow API returned {resp.status_code}: {resp.text[:200]}")
                    vectors.extend([_build_simple_embedding(t, EMBEDDING_DIM) for t in chunk])
                    continue
                payload = resp.json()
                # Restore input order using the index reported for each item.
                for entry in sorted(payload["data"], key=lambda x: x["index"]):
                    vectors.append(np.array(entry["embedding"], dtype=np.float32))
    except Exception as e:
        print(f"[ERROR] SiliconFlow API call failed: {e}")
        vectors = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
    return vectors
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
async def _ollama_embeddings(texts: List[str], model: str) -> List[np.ndarray]:
    """Embed ``texts`` one by one through the Ollama embedding API.

    Accepts responses carrying either an ``embeddings`` list (first element
    used) or a single ``embedding``. A text whose request fails or returns no
    vector gets a pseudo-embedding; any exception replaces the whole result
    with pseudo-embeddings. Every fallback is logged.
    """
    results = []
    try:
        async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
            for text in texts:
                resp = await client.post(
                    EMBEDDING_API_URL,
                    json={"model": model, "input": text}
                )
                if resp.status_code != 200:
                    print(f"[ERROR] Ollama API returned {resp.status_code}: {resp.text[:200]}")
                    results.append(_build_simple_embedding(text, EMBEDDING_DIM))
                    continue
                data = resp.json()
                if "embeddings" in data:
                    vec = np.array(data["embeddings"][0], dtype=np.float32)
                elif "embedding" in data:
                    vec = np.array(data["embedding"], dtype=np.float32)
                else:
                    print("[WARN] Ollama response contained no embedding, using pseudo embedding")
                    vec = _build_simple_embedding(text, EMBEDDING_DIM)
                results.append(vec)
    except Exception as e:
        # Best-effort fallback is kept, but the failure is no longer silent.
        print(f"[ERROR] Ollama API call failed: {e}")
        results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
    return results
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
async def _openai_embeddings(texts: List[str], model: str) -> List[np.ndarray]:
    """Embed ``texts`` through an OpenAI-compatible embedding API (e.g., vLLM).

    The endpoint is ``$OPENAI_API_BASE/v1/embeddings`` (default base
    ``http://localhost:8000``) and all texts are sent in one request. A
    non-200 status or any exception yields pseudo-embeddings for every text;
    both cases are logged.
    """
    api_url = os.environ.get("OPENAI_API_BASE", "http://localhost:8000") + "/v1/embeddings"
    api_key = os.environ.get("OPENAI_API_KEY", "no-key")
    results = []
    try:
        async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
            resp = await client.post(
                api_url,
                headers={"Authorization": f"Bearer {api_key}"},
                json={"model": model, "input": texts}
            )
            if resp.status_code == 200:
                data = resp.json()
                # Order by the reported index when present; the sort is stable,
                # so items without an index keep their response order.
                for item in sorted(data["data"], key=lambda x: x.get("index", 0)):
                    vec = np.array(item["embedding"], dtype=np.float32)
                    results.append(vec)
            else:
                print(f"[ERROR] OpenAI-compatible API returned {resp.status_code}: {resp.text[:200]}")
                results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
    except Exception as e:
        # Best-effort fallback is kept, but the failure is no longer silent.
        print(f"[ERROR] OpenAI-compatible API call failed: {e}")
        results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
    return results
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
async def rerank_candidates(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    model: Optional[str] = None,
) -> List[Dict]:
    """Score ``documents`` against ``query`` with the SiliconFlow reranker API.

    Args:
        query: Text the documents are ranked against.
        documents: Candidate texts; positions in this list are the ``index``
            values of the returned entries.
        top_n: Maximum number of entries requested from the API. Defaults to
            (and is capped at) ``len(documents)``.
        model: Reranker model name; defaults to ``RERANKER_MODEL``.

    Returns:
        A list of ``{"index": int, "relevance_score": float}`` dicts sorted by
        score in descending order. When no API key is configured, the
        reranker is disabled, the API answers with a non-200 status, or the
        call raises, a neutral list with one entry per document and a score of
        ``0.0`` is returned instead. An empty ``documents`` list yields ``[]``.
    """
    def _neutral_scores() -> List[Dict]:
        # Fallback: every document present, all with the same zero score.
        return [{"index": i, "relevance_score": 0.0} for i in range(len(documents))]

    model = model or RERANKER_MODEL
    api_key = SILICONFLOW_API_KEY

    if not api_key or not RERANKER_ENABLED:
        return _neutral_scores()

    if not documents:
        return []

    # Never ask for more results than there are documents.
    top_n = min(top_n or len(documents), len(documents))

    try:
        async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
            resp = await client.post(
                RERANKER_API_URL,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": model,
                    "query": query,
                    "documents": documents,
                    "top_n": top_n,
                    "return_documents": False,
                },
            )
            if resp.status_code == 200:
                data = resp.json()
                results = data.get("results", [])
                return sorted(results, key=lambda x: x["relevance_score"], reverse=True)
            print(f"[ERROR] Reranker API returned {resp.status_code}: {resp.text[:200]}")
            return _neutral_scores()
    except Exception as e:
        print(f"[ERROR] Reranker API call failed: {e}")
        return _neutral_scores()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    A vector of zero length has no direction, so the result is ``0.0``
    whenever either input has a zero norm.
    """
    magnitudes = (np.linalg.norm(a), np.linalg.norm(b))
    if any(magnitude == 0 for magnitude in magnitudes):
        return 0.0
    return float(np.dot(a, b) / (magnitudes[0] * magnitudes[1]))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def batch_cosine_similarity(source_vecs: np.ndarray, target_vecs: np.ndarray) -> np.ndarray:
    """Return the matrix of cosine similarities between two sets of row vectors.

    source_vecs: (M, D), target_vecs: (N, D)
    Returns: (M, N) matrix whose entry [i, j] compares source row i with
    target row j. Rows with zero norm produce zeros.
    """
    def _unit_rows(matrix: np.ndarray) -> np.ndarray:
        lengths = np.linalg.norm(matrix, axis=1, keepdims=True)
        # Divide zero-length rows by 1 instead of 0 so they stay all-zero.
        return matrix / np.where(lengths == 0, 1, lengths)

    return _unit_rows(source_vecs) @ _unit_rows(target_vecs).T
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def embedding_to_bytes(vec: np.ndarray) -> bytes:
    """Serialize ``vec`` as raw float32 bytes (the inverse of ``bytes_to_embedding``)."""
    as_float32 = vec.astype(np.float32)
    return as_float32.tobytes()
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def bytes_to_embedding(data: bytes) -> np.ndarray:
    """Interpret ``data`` as a flat sequence of float32 values and return it as an array."""
    vector = np.frombuffer(data, dtype=np.float32)
    return vector
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_match_level(score: float) -> str:
    """Translate a similarity score into a coarse match label.

    Scores of at least 0.90 are "high", at least 0.80 "possible", at least
    0.70 "low_confidence"; anything lower is "no_match".
    """
    # Ordered from the strictest threshold down; the first one met wins.
    tiers = (
        (0.90, "high"),
        (0.80, "possible"),
        (0.70, "low_confidence"),
    )
    for floor, label in tiers:
        if score >= floor:
            return label
    return "no_match"
|
hf-vector-match-api/services/excel_service.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import hashlib
|
| 3 |
+
import json
|
| 4 |
+
import openpyxl
|
| 5 |
+
from typing import List, Dict, Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Uploaded files are stored in "data/uploads" under the parent of this file's directory.
UPLOAD_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "uploads")
|
| 9 |
+
# Create the upload directory at import time; exist_ok avoids an error when it already exists.
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def save_upload_file(file_bytes: bytes, filename: str) -> str:
    """Write an uploaded file into ``UPLOAD_DIR`` and return the path it was saved to.

    Args:
        file_bytes: Raw content of the upload.
        filename: Name supplied with the upload. Only its final path
            component is used, so directory parts cannot redirect the write
            outside ``UPLOAD_DIR``.

    Returns:
        The full path of the written file.

    Raises:
        ValueError: If ``filename`` has no usable final component
            (empty, ``.`` or ``..``).
    """
    # The name comes from the client: normalise Windows separators, then drop
    # any directory components ("../../x", "/etc/passwd", "C:\\dir\\x").
    safe_name = os.path.basename(filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload filename: {filename!r}")

    filepath = os.path.join(UPLOAD_DIR, safe_name)
    with open(filepath, "wb") as f:
        f.write(file_bytes)
    return filepath
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Header names (compared after strip()) that get_sheet_info leaves out of its "columns" list:
# sequence / row-number / ID style columns ("序号" serial number, "行号" row number, "编号" code, "行" row).
EXCLUDED_FIELDS = {"序号", "行号", "编号", "id", "ID", "Id", "no", "No", "NO", "行", "#"}
|
| 20 |
+
|
| 21 |
+
def get_sheet_info(filepath: str) -> Dict:
    """Read the sheet names and header rows of an Excel workbook.

    Args:
        filepath: Path of the workbook to inspect (opened read-only).

    Returns:
        A dict with three keys:
        ``"sheet_names"`` - list of sheet names;
        ``"all_columns"`` - per sheet, the header of every column, where an
        empty header cell is replaced by ``"列<n>"`` (n = 1-based column position);
        ``"columns"`` - the same headers minus those whose stripped text is in
        ``EXCLUDED_FIELDS``.
    """
    wb = openpyxl.load_workbook(filepath, read_only=True)
    # Close the workbook even if reading a sheet raises, so the file handle
    # is not left open.
    try:
        result = {"sheet_names": wb.sheetnames, "columns": {}, "all_columns": {}}
        for sheet_name in wb.sheetnames:
            ws = wb[sheet_name]
            headers = []
            # Only the first row is read; an empty sheet yields no headers.
            for row in ws.iter_rows(min_row=1, max_row=1, values_only=True):
                headers = [str(c) if c else f"列{i+1}" for i, c in enumerate(row)]
            result["all_columns"][sheet_name] = headers
            result["columns"][sheet_name] = [
                h for h in headers if h.strip() not in EXCLUDED_FIELDS
            ]
        return result
    finally:
        wb.close()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def parse_excel_rows(
    filepath: str,
    sheet_name: str,
    vector_fields: List[str],
) -> List[Dict]:
    """Turn the data rows of one sheet into records ready for vectorization.

    The first row is treated as the header (empty header cells become
    ``"列<n>"``). For each following row the values of ``vector_fields`` that
    are non-empty are joined with a single space to form the text to embed;
    rows whose text is blank are skipped.

    Args:
        filepath: Path of the workbook (opened read-only).
        sheet_name: Name of the sheet to read.
        vector_fields: Header names whose values make up the embedded text,
            in the order they should be concatenated.

    Returns:
        One dict per kept row with the keys ``"row_number"`` (1-based sheet
        row, the header being row 1), ``"raw_text"``, ``"text_hash"``
        (SHA-256 hex digest of the UTF-8 encoded text) and ``"field_values"``
        (JSON string of all header -> cell-text pairs of the row).

    Raises:
        KeyError: If ``sheet_name`` is not in the workbook (raised by openpyxl).
    """
    wb = openpyxl.load_workbook(filepath, read_only=True)
    # Close the workbook even when the sheet lookup or a row fails to parse.
    try:
        ws = wb[sheet_name]
        rows_data = []
        headers = []
        for row_idx, row in enumerate(ws.iter_rows(values_only=True)):
            if row_idx == 0:
                headers = [str(c) if c else f"列{i+1}" for i, c in enumerate(row)]
                continue

            # Cells beyond the header width are ignored; None becomes "".
            row_dict = {}
            for i, val in enumerate(row):
                if i < len(headers):
                    row_dict[headers[i]] = str(val) if val is not None else ""

            text_parts = [row_dict[field] for field in vector_fields if row_dict.get(field)]
            raw_text = " ".join(text_parts)

            if not raw_text.strip():
                continue

            text_hash = hashlib.sha256(raw_text.encode("utf-8")).hexdigest()

            rows_data.append({
                "row_number": row_idx + 1,
                "raw_text": raw_text,
                "text_hash": text_hash,
                "field_values": json.dumps(row_dict, ensure_ascii=False),
            })
        return rows_data
    finally:
        wb.close()
|
hf-vector-match-api/services/match_service.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import List, Dict, Optional
|
| 4 |
+
from sqlalchemy.orm import Session
|
| 5 |
+
|
| 6 |
+
from models import (
|
| 7 |
+
VectorMatchTask, VectorDataset, VectorDataRow,
|
| 8 |
+
VectorEmbedding, MatchResult
|
| 9 |
+
)
|
| 10 |
+
from services.embedding_service import (
|
| 11 |
+
get_embeddings_batch, batch_cosine_similarity,
|
| 12 |
+
embedding_to_bytes, bytes_to_embedding, get_match_level,
|
| 13 |
+
rerank_candidates, RERANKER_ENABLED
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Number of data rows handled per vectorization step in run_match_task
# (at most one embedding request and one commit per batch).
BATCH_SIZE = 32
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _safe_commit(db):
|
| 21 |
+
"""提交事务,连接断开时自动回滚并重试"""
|
| 22 |
+
try:
|
| 23 |
+
db.commit()
|
| 24 |
+
except Exception:
|
| 25 |
+
db.rollback()
|
| 26 |
+
try:
|
| 27 |
+
db.commit()
|
| 28 |
+
except Exception:
|
| 29 |
+
db.rollback()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def run_match_task(task_id: int, db_factory):
    """Main matching pipeline: parse → vectorize → match → save results.

    Loads the ``VectorMatchTask`` with ``task_id``, makes sure every source and
    target row has a stored embedding (reusing an existing embedding with the
    same ``text_hash`` where possible), computes the source x target cosine
    similarity matrix, keeps the best ``top_k`` targets per source row,
    optionally reranks them, and stores one ``MatchResult`` per kept pair.
    Progress fields on the task are updated along the way.

    Args:
        task_id: Primary key of the task to run. Nothing happens if no such
            task exists.
        db_factory: Zero-argument callable returning a new session; the
            session is closed before this coroutine returns.

    Raises:
        Exception: Any error raised by the pipeline is re-raised after the
            task has been marked ``"failed"`` (on a best-effort basis).
    """
    db: Session = db_factory()
    try:
        task = db.query(VectorMatchTask).get(task_id)
        if not task:
            return
        task.status = "running"
        _safe_commit(db)

        # Step 1: Parse source (only the progress field is updated here)
        task.progress_parse_source = 100
        _safe_commit(db)

        # Step 2: Parse target (only the progress field is updated here)
        task.progress_parse_target = 100
        _safe_commit(db)

        # Step 3: Vectorize
        source_rows = (
            db.query(VectorDataRow)
            .filter(VectorDataRow.dataset_id == task.source_dataset_id)
            .all()
        )
        target_rows = (
            db.query(VectorDataRow)
            .filter(VectorDataRow.dataset_id == task.target_dataset_id)
            .all()
        )

        task.source_row_count = len(source_rows)
        task.target_row_count = len(target_rows)
        _safe_commit(db)

        all_rows = source_rows + target_rows
        reused = 0
        new_count = 0

        for i in range(0, len(all_rows), BATCH_SIZE):
            batch = all_rows[i : i + BATCH_SIZE]
            texts_to_embed = []
            rows_to_embed = []

            for row in batch:
                existing = (
                    db.query(VectorEmbedding)
                    .filter(VectorEmbedding.text_hash == row.text_hash)
                    .first()
                )
                if existing and existing.data_row_id != row.id:
                    # Same text already embedded for another row: copy the vector.
                    new_emb = VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=existing.embedding,
                        model_name=existing.model_name,
                        dimension=existing.dimension,
                    )
                    db.add(new_emb)
                    reused += 1
                elif existing:
                    # The stored embedding already belongs to this row.
                    reused += 1
                else:
                    texts_to_embed.append(row.raw_text)
                    rows_to_embed.append(row)

            if texts_to_embed:
                embeddings = await get_embeddings_batch(texts_to_embed)
                for row, vec in zip(rows_to_embed, embeddings):
                    emb = VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=embedding_to_bytes(vec),
                        model_name="default",
                        dimension=len(vec),
                    )
                    db.add(emb)
                    new_count += 1

            progress = min(100, int((i + len(batch)) / max(len(all_rows), 1) * 100))
            task.progress_vectorize = progress
            task.reused_vectors = reused
            task.new_vectors = new_count
            _safe_commit(db)

        task.progress_vectorize = 100
        _safe_commit(db)

        # Step 4: Load candidate range (only the progress field is updated here)
        task.progress_load_candidates = 100
        _safe_commit(db)

        # Step 5: Similarity calculation
        source_embeddings = []
        source_row_ids = []
        for row in source_rows:
            emb = db.query(VectorEmbedding).filter(VectorEmbedding.data_row_id == row.id).first()
            if emb:
                source_embeddings.append(bytes_to_embedding(emb.embedding))
                source_row_ids.append(row.id)

        target_embeddings = []
        target_row_ids = []
        for row in target_rows:
            emb = db.query(VectorEmbedding).filter(VectorEmbedding.data_row_id == row.id).first()
            if emb:
                target_embeddings.append(bytes_to_embedding(emb.embedding))
                target_row_ids.append(row.id)

        if not source_embeddings or not target_embeddings:
            # Nothing to compare: finish without producing results.
            task.status = "completed"
            task.progress_similarity = 100
            task.progress_save_results = 100
            _safe_commit(db)
            return

        source_matrix = np.stack(source_embeddings)
        target_matrix = np.stack(target_embeddings)

        sim_matrix = batch_cosine_similarity(source_matrix, target_matrix)
        task.progress_similarity = 100
        _safe_commit(db)

        # Step 6: Collect Top-K candidates per source row
        # top_k is the initial candidate count; rerank_top_k is how many are
        # kept after reranking.
        initial_k = task.top_k
        initial_k = min(initial_k, len(target_row_ids))

        # Build raw_text lookup for reranker
        source_text_map = {}
        target_text_map = {}
        if RERANKER_ENABLED:
            for row in source_rows:
                source_text_map[row.id] = row.raw_text
            for row in target_rows:
                target_text_map[row.id] = row.raw_text

        high_count = 0
        low_count = 0
        total_source = len(source_row_ids)

        for idx, src_id in enumerate(source_row_ids):
            scores = sim_matrix[idx]
            # Indices of the best-scoring targets, highest similarity first.
            top_indices = np.argsort(scores)[::-1][:initial_k]

            candidates = []
            for tgt_idx in top_indices:
                candidates.append({
                    "tgt_idx": tgt_idx,
                    "tgt_row_id": target_row_ids[tgt_idx],
                    "sim_score": float(scores[tgt_idx]),
                    "rerank_score": None,
                })

            # Step 6.5: Rerank candidates
            if RERANKER_ENABLED and candidates:
                query_text = source_text_map.get(src_id, "")
                doc_texts = [target_text_map.get(c["tgt_row_id"], "") for c in candidates]
                try:
                    rerank_top_k = task.rerank_top_k or task.top_k
                    rerank_results = await rerank_candidates(
                        query=query_text,
                        documents=doc_texts,
                        top_n=rerank_top_k,
                    )
                    # Map rerank scores back to candidates
                    for rr in rerank_results:
                        orig_idx = rr["index"]
                        if orig_idx < len(candidates):
                            candidates[orig_idx]["rerank_score"] = rr["relevance_score"]

                    # Sort by rerank_score (desc), keep rerank_top_k.
                    # Candidates the reranker did not score sort last; -inf is
                    # used so that even negative relevance scores rank above them.
                    candidates.sort(
                        key=lambda c: c["rerank_score"] if c["rerank_score"] is not None else float("-inf"),
                        reverse=True,
                    )
                    candidates = candidates[:rerank_top_k]
                except Exception as e:
                    print(f"[WARN] Rerank failed for source {src_id}: {e}")
                    candidates = candidates[:task.top_k]

                progress = min(100, int((idx + 1) / total_source * 100))
                task.progress_rerank = progress
                if idx % 20 == 0:
                    _safe_commit(db)
            else:
                candidates = candidates[:task.top_k]

            # Save results
            for rank, c in enumerate(candidates):
                level = get_match_level(c["sim_score"])
                result = MatchResult(
                    task_id=task.id,
                    source_row_id=src_id,
                    target_row_id=c["tgt_row_id"],
                    similarity_score=c["sim_score"],
                    rerank_score=c["rerank_score"],
                    rank=rank + 1,
                    rerank_rank=rank + 1 if c["rerank_score"] is not None else None,
                    candidate_scope=task.candidate_scope,
                    match_level=level,
                )
                db.add(result)

                # Summary counters look at the best candidate of each source row only.
                if rank == 0:
                    if c["sim_score"] >= 0.90:
                        high_count += 1
                    elif c["sim_score"] < 0.70:
                        low_count += 1

            progress = min(100, int((idx + 1) / total_source * 100))
            task.progress_save_results = progress
            if idx % 50 == 0:
                _safe_commit(db)

        task.high_match_count = high_count
        task.low_confidence_count = low_count
        task.progress_rerank = 100
        task.progress_save_results = 100
        task.status = "completed"
        _safe_commit(db)

    except Exception:
        # Marking the task as failed is best-effort and must never hide the
        # original error. Roll back first: after a failed flush/commit the
        # session refuses further queries until it has been rolled back.
        try:
            db.rollback()
            task = db.query(VectorMatchTask).get(task_id)
            if task:
                task.status = "failed"
                _safe_commit(db)
        except Exception as mark_error:
            print(f"[ERROR] Could not mark task {task_id} as failed: {mark_error}")
        # Bare raise re-raises the pipeline error with its original traceback.
        raise
    finally:
        db.close()
|