teryryy commited on
Commit
010f0b1
·
verified ·
1 Parent(s): 37bec6e

Upload 13 files

Browse files
hf-vector-match-api/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
hf-vector-match-api/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .env
4
+ .env.local
5
+ data/uploads/
hf-vector-match-api/Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user requirements.txt .
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+
14
+ RUN mkdir -p data/uploads
15
+
16
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
hf-vector-match-api/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Vector Match Api
3
+ emoji: 😻
4
+ colorFrom: pink
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
hf-vector-match-api/database.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from urllib.parse import quote_plus
4
+
5
+ from dotenv import load_dotenv
6
+ from sqlalchemy import create_engine
7
+
8
+ _env_dir = Path(__file__).resolve().parent
9
+ load_dotenv(_env_dir / ".env.local", override=False) # 本地开发优先
10
+ load_dotenv(_env_dir / ".env", override=False) # 兜底(Docker/线上)
11
+ from sqlalchemy.orm import sessionmaker, declarative_base
12
+
13
+ # 可选:整串 URL(优先级最高),例如 postgresql+psycopg2://user:pass@host:5432/dbname
14
+ _database_url = os.environ.get("DATABASE_URL", "").strip()
15
+
16
+ if _database_url:
17
+ SQLALCHEMY_DATABASE_URL = _database_url
18
+ else:
19
+ PG_HOST = os.environ.get("PG_HOST", "localhost")
20
+ PG_PORT = os.environ.get("PG_PORT", "5432")
21
+ PG_USER = os.environ.get("PG_USER", "postgres")
22
+ PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")
23
+ PG_DB = os.environ.get("PG_DB", "vector_match")
24
+ _pw = quote_plus(PG_PASSWORD)
25
+ SQLALCHEMY_DATABASE_URL = (
26
+ f"postgresql+psycopg2://{PG_USER}:{_pw}@{PG_HOST}:{PG_PORT}/{PG_DB}"
27
+ )
28
+
29
+ PG_SCHEMA = os.environ.get("PG_SCHEMA", "vector_match")
30
+
31
+ engine = create_engine(
32
+ SQLALCHEMY_DATABASE_URL,
33
+ pool_pre_ping=True,
34
+ pool_size=20,
35
+ pool_recycle=180,
36
+ pool_timeout=60,
37
+ max_overflow=10,
38
+ )
39
+
40
+ # 每次连接自动切换到 vector_match schema
41
+ from sqlalchemy import event
42
+
43
+ @event.listens_for(engine, "connect")
44
+ def _set_search_path(dbapi_conn, connection_record):
45
+ cursor = dbapi_conn.cursor()
46
+ cursor.execute(f"SET search_path TO {PG_SCHEMA}, public")
47
+ cursor.close()
48
+
49
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
50
+ Base = declarative_base()
51
+
52
+ DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
53
+ os.makedirs(DATA_DIR, exist_ok=True)
54
+
55
+ _SCHEMA_TABLES = (
56
+ "vector_match_task",
57
+ "vector_dataset",
58
+ "vector_data_row",
59
+ "vector_embedding",
60
+ "match_result",
61
+ )
62
+
63
+
64
+ def get_db():
65
+ db = SessionLocal()
66
+ try:
67
+ yield db
68
+ except Exception:
69
+ db.rollback()
70
+ raise
71
+ finally:
72
+ db.close()
73
+
74
+
75
+ def _table_sql_name(table: str) -> str:
76
+ if engine.dialect.name == "postgresql":
77
+ return f'"{PG_SCHEMA}"."{table}"'
78
+ return table
79
+
80
+
81
+ def _column_names_conn(conn, table: str) -> set:
82
+ """
83
+ 与当前连接共用同一事务,避免在持有 ALTER 锁的事务内再用 inspect(engine) 开新连接查目录,
84
+ 否则 PostgreSQL 上会自锁(会话 A 持锁等 B 查元数据,B 等 A 释放锁)。
85
+ """
86
+ from sqlalchemy import inspect
87
+
88
+ insp = inspect(conn)
89
+ schema = PG_SCHEMA if engine.dialect.name == "postgresql" else None
90
+ return {c["name"] for c in insp.get_columns(table, schema=schema)}
91
+
92
+
93
+ def _ensure_is_archived_column():
94
+ """旧库无 is_archived 时补列。"""
95
+ from sqlalchemy import inspect, text
96
+
97
+ insp = inspect(engine)
98
+ schema = PG_SCHEMA if engine.dialect.name == "postgresql" else None
99
+ try:
100
+ cols = insp.get_columns("vector_match_task", schema=schema)
101
+ except Exception:
102
+ return
103
+ if any(c["name"] == "is_archived" for c in cols):
104
+ return
105
+ ft = _table_sql_name("vector_match_task")
106
+ ddl = (
107
+ f"ALTER TABLE {ft} ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0"
108
+ if engine.dialect.name == "postgresql"
109
+ else "ALTER TABLE vector_match_task ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0"
110
+ )
111
+ with engine.begin() as conn:
112
+ conn.execute(text(ddl))
113
+
114
+
115
+ def _ensure_time_is_delete_columns():
116
+ """
117
+ 统一:created_at→created_time;任务表 updated_at→updated_time;
118
+ 各表补 is_delete;is_deleted→is_delete;遗留 deleted_at 迁移后删除。
119
+ """
120
+ from sqlalchemy import text
121
+
122
+ ft_task = _table_sql_name("vector_match_task")
123
+
124
+ with engine.begin() as conn:
125
+ cols = _column_names_conn(conn, "vector_match_task")
126
+ if "created_at" in cols and "created_time" not in cols:
127
+ conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN created_at TO created_time"))
128
+ if "updated_at" in cols and "updated_time" not in cols:
129
+ conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN updated_at TO updated_time"))
130
+ if "is_deleted" in cols and "is_delete" not in cols:
131
+ conn.execute(text(f"ALTER TABLE {ft_task} RENAME COLUMN is_deleted TO is_delete"))
132
+ cols = _column_names_conn(conn, "vector_match_task")
133
+ if "deleted_at" in cols:
134
+ if "is_delete" not in cols:
135
+ conn.execute(
136
+ text(f"ALTER TABLE {ft_task} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
137
+ )
138
+ conn.execute(text(f"UPDATE {ft_task} SET is_delete = 1 WHERE deleted_at IS NOT NULL"))
139
+ conn.execute(text(f"ALTER TABLE {ft_task} DROP COLUMN deleted_at"))
140
+
141
+ for table in _SCHEMA_TABLES:
142
+ ft = _table_sql_name(table)
143
+ with engine.begin() as conn:
144
+ cols = _column_names_conn(conn, table)
145
+ if "created_at" in cols and "created_time" not in cols:
146
+ conn.execute(text(f"ALTER TABLE {ft} RENAME COLUMN created_at TO created_time"))
147
+ cols = _column_names_conn(conn, table)
148
+ if "is_delete" not in cols:
149
+ conn.execute(
150
+ text(f"ALTER TABLE {ft} ADD COLUMN is_delete INTEGER NOT NULL DEFAULT 0")
151
+ )
152
+ cols = _column_names_conn(conn, table)
153
+ if "updated_time" not in cols:
154
+ conn.execute(text(f"ALTER TABLE {ft} ADD COLUMN updated_time TIMESTAMP NULL"))
155
+ if "created_time" in cols:
156
+ conn.execute(
157
+ text(
158
+ f"UPDATE {ft} SET updated_time = created_time "
159
+ f"WHERE updated_time IS NULL"
160
+ )
161
+ )
162
+
163
+
164
+ def init_db():
165
+ from models import Base as ModelBase # noqa: F401
166
+ from sqlalchemy import text
167
+ with engine.connect() as conn:
168
+ conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {PG_SCHEMA}"))
169
+ conn.commit()
170
+ ModelBase.metadata.create_all(bind=engine)
171
+ _ensure_is_archived_column()
172
+ _ensure_time_is_delete_columns()
hf-vector-match-api/main.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import json
4
+ import time
5
+ import datetime
6
+ import httpx
7
+ from typing import List, Optional
8
+
9
+ from fastapi import FastAPI, UploadFile, File, Form, Depends, Query, HTTPException, BackgroundTasks
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from sqlalchemy.orm import Session
12
+
13
+ from database import get_db, init_db, SessionLocal
14
+ from models import (
15
+ VectorMatchTask, VectorDataset, VectorDataRow,
16
+ VectorEmbedding, MatchResult,
17
+ )
18
+ from schemas import (
19
+ TaskCreate, TaskDetail, TaskProgress, TaskListItem,
20
+ MatchResultItem, MatchResultPage, SourceWithCandidates, CandidateDetail,
21
+ UploadResponse, SettingItem, SettingsResponse, DatasetInfo,
22
+ )
23
+ from services.excel_service import save_upload_file, get_sheet_info, parse_excel_rows
24
+ from services.match_service import run_match_task
25
+
26
+
27
+ app = FastAPI(title="VectorMatch API", version="1.0.0")
28
+
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"],
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+ import logging, traceback
38
+ from starlette.requests import Request
39
+ from starlette.responses import JSONResponse
40
+
41
+ logger = logging.getLogger("uvicorn.error")
42
+
43
+ @app.exception_handler(Exception)
44
+ async def global_exception_handler(request: Request, exc: Exception):
45
+ logger.error(f"Unhandled error on {request.method} {request.url}:\n{traceback.format_exc()}")
46
+ return JSONResponse(status_code=500, content={"detail": str(exc)})
47
+
48
+
49
+ # ─── 健康状态缓存 ─────────────────────────────────────────────────────────
50
+ _health_cache = {
51
+ "result": {"embedding_ok": False, "reranker_ok": False, "embedding_model": "",
52
+ "reranker_model": "", "reranker_enabled": False, "has_api_key": False},
53
+ "updated_at": 0,
54
+ }
55
+ _HEALTH_TTL = 30 # 缓存有效期(秒)
56
+
57
+
58
+ async def _do_health_check():
59
+ """执行真正的 API 探活,更新缓存"""
60
+ import services.embedding_service as es
61
+ api_key = es.SILICONFLOW_API_KEY
62
+ result = {
63
+ "embedding_ok": False,
64
+ "reranker_ok": False,
65
+ "embedding_model": es.EMBEDDING_MODEL,
66
+ "reranker_model": es.RERANKER_MODEL,
67
+ "reranker_enabled": es.RERANKER_ENABLED,
68
+ "has_api_key": bool(api_key),
69
+ }
70
+ if api_key:
71
+ try:
72
+ async with httpx.AsyncClient(timeout=5.0, proxies={}) as client:
73
+ try:
74
+ emb_resp = await client.post(
75
+ "https://api.siliconflow.cn/v1/embeddings",
76
+ headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
77
+ json={"model": es.EMBEDDING_MODEL, "input": ["ping"]},
78
+ )
79
+ result["embedding_ok"] = emb_resp.status_code == 200
80
+ except Exception:
81
+ pass
82
+ if es.RERANKER_ENABLED:
83
+ try:
84
+ rerank_resp = await client.post(
85
+ "https://api.siliconflow.cn/v1/rerank",
86
+ headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
87
+ json={"model": es.RERANKER_MODEL, "query": "ping", "documents": ["pong"], "top_n": 1},
88
+ )
89
+ result["reranker_ok"] = rerank_resp.status_code == 200
90
+ except Exception:
91
+ pass
92
+ except Exception:
93
+ pass
94
+ _health_cache["result"] = result
95
+ _health_cache["updated_at"] = time.time()
96
+ return result
97
+
98
+
99
+ async def _health_polling_loop():
100
+ """后台定时探活循环"""
101
+ while True:
102
+ try:
103
+ await _do_health_check()
104
+ except Exception:
105
+ pass
106
+ await asyncio.sleep(_HEALTH_TTL)
107
+
108
+
109
+ @app.on_event("startup")
110
+ async def startup():
111
+ init_db()
112
+ # 启动后台健康检查循环
113
+ asyncio.create_task(_health_polling_loop())
114
+
115
+
116
+ # ─── Upload Excel ───────────────────────────────────────────────────────────
117
+ @app.post("/api/upload", response_model=UploadResponse)
118
+ async def upload_excel(
119
+ file: UploadFile = File(...),
120
+ dataset_role: str = Form("source"),
121
+ db: Session = Depends(get_db),
122
+ ):
123
+ content = await file.read()
124
+ filepath = save_upload_file(content, file.filename)
125
+ info = get_sheet_info(filepath)
126
+
127
+ dataset = VectorDataset(
128
+ name=file.filename,
129
+ file_name=file.filename,
130
+ dataset_role=dataset_role,
131
+ data_scope="task",
132
+ )
133
+ db.add(dataset)
134
+ db.commit()
135
+ db.refresh(dataset)
136
+
137
+ return UploadResponse(
138
+ dataset_id=dataset.id,
139
+ file_name=file.filename,
140
+ sheet_names=info["sheet_names"],
141
+ columns=info["columns"],
142
+ all_columns=info.get("all_columns", info["columns"]),
143
+ )
144
+
145
+
146
+ # ─── Configure dataset (sheet, fields) ─────────────────────────────────────
147
+ @app.post("/api/dataset/{dataset_id}/configure")
148
+ def configure_dataset(
149
+ dataset_id: int,
150
+ sheet_name: str = Form(...),
151
+ vector_fields: str = Form(...),
152
+ db: Session = Depends(get_db),
153
+ ):
154
+ dataset = db.query(VectorDataset).get(dataset_id)
155
+ if not dataset:
156
+ raise HTTPException(404, "Dataset not found")
157
+
158
+ dataset.sheet_name = sheet_name
159
+ dataset.vector_fields = vector_fields
160
+ db.commit()
161
+
162
+ fields = json.loads(vector_fields)
163
+ import os
164
+ filepath = os.path.join(
165
+ os.path.dirname(__file__), "data", "uploads", dataset.file_name
166
+ )
167
+ rows = parse_excel_rows(filepath, sheet_name, fields)
168
+
169
+ for row_data in rows:
170
+ dr = VectorDataRow(
171
+ dataset_id=dataset.id,
172
+ dataset_role=dataset.dataset_role,
173
+ data_scope=dataset.data_scope,
174
+ row_number=row_data["row_number"],
175
+ raw_text=row_data["raw_text"],
176
+ text_hash=row_data["text_hash"],
177
+ field_values=row_data["field_values"],
178
+ )
179
+ db.add(dr)
180
+ dataset.row_count = len(rows)
181
+ db.commit()
182
+
183
+ return {"status": "ok", "row_count": len(rows)}
184
+
185
+
186
+ # ─── Get dataset info ──────────────────────────────────────────────────────
187
+ @app.get("/api/dataset/{dataset_id}", response_model=DatasetInfo)
188
+ def get_dataset(dataset_id: int, db: Session = Depends(get_db)):
189
+ dataset = db.query(VectorDataset).get(dataset_id)
190
+ if not dataset:
191
+ raise HTTPException(404, "Dataset not found")
192
+ return dataset
193
+
194
+
195
+ # ─── Create & start task ───────────────────────────────────────────────────
196
+ @app.post("/api/task", response_model=TaskDetail)
197
+ async def create_task(
198
+ background_tasks: BackgroundTasks,
199
+ source_dataset_id: int = Form(...),
200
+ target_dataset_id: int = Form(...),
201
+ match_mode: str = Form("two_file"),
202
+ top_k: int = Form(10),
203
+ rerank_top_k: int = Form(3),
204
+ min_threshold: float = Form(0.70),
205
+ candidate_scope: str = Form("current_task_target"),
206
+ db: Session = Depends(get_db),
207
+ ):
208
+ now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
209
+ task_code = now.strftime("%Y%m%d%H%M%S") + f"{now.microsecond // 1000:03d}"
210
+
211
+ src = db.query(VectorDataset).get(source_dataset_id)
212
+ tgt = db.query(VectorDataset).get(target_dataset_id)
213
+ if not src or not tgt:
214
+ raise HTTPException(400, "Source or target dataset not found")
215
+
216
+ task = VectorMatchTask(
217
+ task_code=task_code,
218
+ match_mode=match_mode,
219
+ candidate_scope=candidate_scope,
220
+ source_dataset_id=source_dataset_id,
221
+ target_dataset_id=target_dataset_id,
222
+ top_k=top_k,
223
+ rerank_top_k=rerank_top_k,
224
+ min_threshold=min_threshold,
225
+ status="pending",
226
+ )
227
+ db.add(task)
228
+ db.commit()
229
+ db.refresh(task)
230
+
231
+ src.task_id = task.id
232
+ tgt.task_id = task.id
233
+ db.query(VectorDataRow).filter(VectorDataRow.dataset_id == src.id).update({"task_id": task.id})
234
+ db.query(VectorDataRow).filter(VectorDataRow.dataset_id == tgt.id).update({"task_id": task.id})
235
+ db.commit()
236
+
237
+ background_tasks.add_task(_run_task_in_background, task.id)
238
+
239
+ db.refresh(task)
240
+ return task
241
+
242
+
243
+ def _run_task_in_background(task_id: int):
244
+ loop = asyncio.new_event_loop()
245
+ asyncio.set_event_loop(loop)
246
+ loop.run_until_complete(run_match_task(task_id, SessionLocal))
247
+ loop.close()
248
+
249
+
250
+ def _get_alive_task(db: Session, task_id: int) -> Optional[VectorMatchTask]:
251
+ """未软删除的任务(is_delete=0)。"""
252
+ task = db.query(VectorMatchTask).get(task_id)
253
+ if not task or (task.is_delete or 0) == 1:
254
+ return None
255
+ return task
256
+
257
+
258
+ # ─── Task progress ─────────────────────────────────────────────────────────
259
+ @app.get("/api/task/{task_id}/progress", response_model=TaskProgress)
260
+ def get_task_progress(task_id: int, db: Session = Depends(get_db)):
261
+ task = _get_alive_task(db, task_id)
262
+ if not task:
263
+ raise HTTPException(404, "Task not found")
264
+ return task
265
+
266
+
267
+ # ─── Task detail ───────────────────────────────────────────────────────────
268
+ @app.get("/api/task/{task_id}", response_model=TaskDetail)
269
+ def get_task_detail(task_id: int, db: Session = Depends(get_db)):
270
+ task = _get_alive_task(db, task_id)
271
+ if not task:
272
+ raise HTTPException(404, "Task not found")
273
+ return task
274
+
275
+
276
+ # ─── Task list ─────────────────────────────────────────────────────────────
277
+ @app.get("/api/tasks", response_model=List[TaskListItem])
278
+ def list_tasks(
279
+ scope: str = Query("active", description="active=未归档, archived=仅归档, deleted=回收站"),
280
+ db: Session = Depends(get_db),
281
+ ):
282
+ if scope not in ("active", "archived", "deleted"):
283
+ raise HTTPException(400, "scope 须为 active、archived 或 deleted")
284
+ q = db.query(VectorMatchTask)
285
+ if scope == "deleted":
286
+ q = q.filter(VectorMatchTask.is_delete == 1)
287
+ else:
288
+ q = q.filter(VectorMatchTask.is_delete == 0)
289
+ if scope == "archived":
290
+ q = q.filter(VectorMatchTask.is_archived == 1)
291
+ else:
292
+ q = q.filter(VectorMatchTask.is_archived == 0)
293
+ tasks = q.order_by(VectorMatchTask.created_time.desc()).all()
294
+ result = []
295
+ for t in tasks:
296
+ src_name = t.source_dataset.name if t.source_dataset else None
297
+ tgt_name = t.target_dataset.name if t.target_dataset else None
298
+ result.append(TaskListItem(
299
+ id=t.id,
300
+ task_code=t.task_code,
301
+ match_mode=t.match_mode,
302
+ candidate_scope=t.candidate_scope,
303
+ source_dataset_name=src_name,
304
+ target_dataset_name=tgt_name,
305
+ status=t.status,
306
+ is_archived=t.is_archived or 0,
307
+ is_delete=t.is_delete or 0,
308
+ created_time=t.created_time,
309
+ ))
310
+ return result
311
+
312
+
313
+ @app.post("/api/task/{task_id}/archive")
314
+ def archive_task(task_id: int, db: Session = Depends(get_db)):
315
+ task = _get_alive_task(db, task_id)
316
+ if not task:
317
+ raise HTTPException(404, "Task not found")
318
+ task.is_archived = 1
319
+ db.commit()
320
+ return {"status": "ok"}
321
+
322
+
323
+ @app.post("/api/task/{task_id}/unarchive")
324
+ def unarchive_task(task_id: int, db: Session = Depends(get_db)):
325
+ task = _get_alive_task(db, task_id)
326
+ if not task:
327
+ raise HTTPException(404, "Task not found")
328
+ task.is_archived = 0
329
+ db.commit()
330
+ return {"status": "ok"}
331
+
332
+
333
+ @app.delete("/api/task/{task_id}")
334
+ def delete_task(task_id: int, db: Session = Depends(get_db)):
335
+ """软删除:is_delete=1,数据仍保留在库中。"""
336
+ task = _get_alive_task(db, task_id)
337
+ if not task:
338
+ raise HTTPException(404, "Task not found")
339
+ task.is_delete = 1
340
+ db.commit()
341
+ return {"status": "ok"}
342
+
343
+
344
+ @app.post("/api/task/{task_id}/restore")
345
+ def restore_task(task_id: int, db: Session = Depends(get_db)):
346
+ """从回收站恢复。"""
347
+ task = db.query(VectorMatchTask).get(task_id)
348
+ if not task or (task.is_delete or 0) != 1:
349
+ raise HTTPException(404, "Task not found or not deleted")
350
+ task.is_delete = 0
351
+ db.commit()
352
+ return {"status": "ok"}
353
+
354
+
355
+ # ─── Match results ─────────────────────────────────────────────────────────
356
+ @app.get("/api/task/{task_id}/results", response_model=MatchResultPage)
357
+ def get_task_results(
358
+ task_id: int,
359
+ page: int = Query(1, ge=1),
360
+ page_size: int = Query(20, ge=1, le=100),
361
+ search: Optional[str] = None,
362
+ level: Optional[str] = None,
363
+ sort: str = "score_desc",
364
+ db: Session = Depends(get_db),
365
+ ):
366
+ if not _get_alive_task(db, task_id):
367
+ raise HTTPException(404, "Task not found")
368
+ query = (
369
+ db.query(MatchResult)
370
+ .filter(MatchResult.task_id == task_id, MatchResult.rank == 1)
371
+ )
372
+ if level and level != "all":
373
+ query = query.filter(MatchResult.match_level == level)
374
+
375
+ if sort == "score_desc":
376
+ query = query.order_by(MatchResult.similarity_score.desc())
377
+ elif sort == "score_asc":
378
+ query = query.order_by(MatchResult.similarity_score.asc())
379
+ else:
380
+ query = query.order_by(MatchResult.source_row_id)
381
+
382
+ total = query.count()
383
+ results = query.offset((page - 1) * page_size).limit(page_size).all()
384
+
385
+ items = []
386
+ for r in results:
387
+ src_row = db.query(VectorDataRow).get(r.source_row_id)
388
+ tgt_row = db.query(VectorDataRow).get(r.target_row_id)
389
+ if search:
390
+ if search.lower() not in (src_row.raw_text or "").lower() and \
391
+ search.lower() not in (tgt_row.raw_text or "").lower():
392
+ continue
393
+ items.append(MatchResultItem(
394
+ id=r.id,
395
+ source_row_id=r.source_row_id,
396
+ source_row_number=src_row.row_number if src_row else 0,
397
+ source_text=src_row.raw_text if src_row else "",
398
+ target_text=tgt_row.raw_text if tgt_row else "",
399
+ similarity_score=r.similarity_score,
400
+ rerank_score=r.rerank_score,
401
+ match_level=r.match_level or "",
402
+ candidate_scope=r.candidate_scope,
403
+ is_confirmed=r.is_confirmed,
404
+ ))
405
+
406
+ return MatchResultPage(items=items, total=total, page=page, page_size=page_size)
407
+
408
+
409
+ # ─── Candidate details for a source row ────────────────────────────────────
410
+ @app.get("/api/task/{task_id}/candidates/{source_row_id}", response_model=SourceWithCandidates)
411
+ def get_candidates(task_id: int, source_row_id: int, db: Session = Depends(get_db)):
412
+ if not _get_alive_task(db, task_id):
413
+ raise HTTPException(404, "Task not found")
414
+ src_row = db.query(VectorDataRow).get(source_row_id)
415
+ if not src_row:
416
+ raise HTTPException(404, "Source row not found")
417
+
418
+ results = (
419
+ db.query(MatchResult)
420
+ .filter(MatchResult.task_id == task_id, MatchResult.source_row_id == source_row_id)
421
+ .order_by(MatchResult.rank)
422
+ .all()
423
+ )
424
+
425
+ candidates = []
426
+ for r in results:
427
+ tgt_row = db.query(VectorDataRow).get(r.target_row_id)
428
+ candidates.append(CandidateDetail(
429
+ rank=r.rank,
430
+ rerank_rank=r.rerank_rank,
431
+ target_row_id=r.target_row_id,
432
+ target_text=tgt_row.raw_text if tgt_row else "",
433
+ similarity_score=r.similarity_score,
434
+ rerank_score=r.rerank_score,
435
+ match_level=r.match_level or "",
436
+ dataset_role="target",
437
+ candidate_scope=r.candidate_scope,
438
+ data_row_id=tgt_row.id if tgt_row else 0,
439
+ is_confirmed=r.is_confirmed,
440
+ ))
441
+
442
+ return SourceWithCandidates(
443
+ source_row_id=src_row.id,
444
+ source_text=src_row.raw_text,
445
+ source_row_number=src_row.row_number,
446
+ dataset_role=src_row.dataset_role,
447
+ data_row_id=src_row.id,
448
+ candidates=candidates,
449
+ )
450
+
451
+
452
+ # ─── Confirm match ─────────────────────────────────────────────────────────
453
+ @app.post("/api/result/{result_id}/confirm")
454
+ def confirm_match(result_id: int, db: Session = Depends(get_db)):
455
+ result = db.query(MatchResult).get(result_id)
456
+ if not result:
457
+ raise HTTPException(404, "Result not found")
458
+ result.is_confirmed = 1
459
+ db.commit()
460
+ return {"status": "ok"}
461
+
462
+
463
+ @app.post("/api/result/{result_id}/ignore")
464
+ def ignore_match(result_id: int, db: Session = Depends(get_db)):
465
+ result = db.query(MatchResult).get(result_id)
466
+ if not result:
467
+ raise HTTPException(404, "Result not found")
468
+ result.is_confirmed = -1
469
+ db.commit()
470
+ return {"status": "ok"}
471
+
472
+
473
+ # ─── Settings (read/write .env) ────────────────────────────────────────────
474
+ _backend_dir = os.path.dirname(os.path.abspath(__file__))
475
+ _env_local = os.path.join(_backend_dir, ".env.local")
476
+ ENV_PATH = _env_local if os.path.exists(_env_local) else os.path.join(_backend_dir, ".env")
477
+
478
+ def _read_env() -> dict:
479
+ result = {}
480
+ if os.path.exists(ENV_PATH):
481
+ with open(ENV_PATH, "r", encoding="utf-8") as f:
482
+ for line in f:
483
+ line = line.strip()
484
+ if line and not line.startswith("#") and "=" in line:
485
+ k, v = line.split("=", 1)
486
+ result[k.strip()] = v.strip()
487
+ return result
488
+
489
+ def _write_env(settings: dict):
490
+ with open(ENV_PATH, "w", encoding="utf-8") as f:
491
+ for k, v in settings.items():
492
+ f.write(f"{k}={v}\n")
493
+
494
+ @app.get("/api/settings", response_model=SettingsResponse)
495
+ def get_settings():
496
+ return SettingsResponse(settings=_read_env())
497
+
498
+ @app.post("/api/settings")
499
+ async def update_settings(items: List[SettingItem]):
500
+ current = _read_env()
501
+ for item in items:
502
+ current[item.key] = item.value
503
+ _write_env(current)
504
+ # 保存后自动重载环境变量,无需手动重启
505
+ from dotenv import load_dotenv
506
+ load_dotenv(ENV_PATH, override=True)
507
+ # 同步更新 embedding_service 模块中的配置常量
508
+ import services.embedding_service as es
509
+ es.SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
510
+ es.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-m3")
511
+ es.EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "1024"))
512
+ es.RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "Qwen/Qwen3-VL-Reranker-8B")
513
+ es.RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() == "true"
514
+ # 立即刷新健康缓存,前端下次请求即可拿到最新状态
515
+ await _do_health_check()
516
+ return {"status": "ok", "message": "已保存��配置已实时生效"}
517
+
518
+
519
+ # ─── 健康检查(返回后端缓存,秒级响应)────────────────────────────────────────
520
+ @app.get("/api/health")
521
+ async def health_check(force: bool = False):
522
+ """返回缓存的健康状态,force=true 时立即刷新"""
523
+ if force or time.time() - _health_cache["updated_at"] > _HEALTH_TTL:
524
+ await _do_health_check()
525
+ return _health_cache["result"]
526
+
527
+
528
+ # ─── Export results ────────────────────────────────────────────────────────
529
+ @app.get("/api/task/{task_id}/export")
530
+ def export_results(task_id: int, db: Session = Depends(get_db)):
531
+ import io
532
+ import openpyxl
533
+ from openpyxl.styles import Font, Alignment, PatternFill
534
+ from fastapi.responses import StreamingResponse
535
+
536
+ task = _get_alive_task(db, task_id)
537
+ if not task:
538
+ raise HTTPException(404, "Task not found")
539
+
540
+ results = (
541
+ db.query(MatchResult)
542
+ .filter(MatchResult.task_id == task_id)
543
+ .order_by(MatchResult.source_row_id, MatchResult.rank)
544
+ .all()
545
+ )
546
+
547
+ from openpyxl.styles import Font, PatternFill, Alignment
548
+
549
+ wb = openpyxl.Workbook()
550
+ ws = wb.active
551
+ ws.title = "匹配结果"
552
+
553
+ headers = ["源行号", "源数据内容", "候选排名", "目标候选内容", "相似度(%)", "精排分", "匹配等级", "候选来源"]
554
+ ws.append(headers)
555
+
556
+ # Header styling
557
+ header_font = Font(bold=True, color="FFFFFF")
558
+ header_fill = PatternFill(start_color="1F4E79", end_color="1F4E79", fill_type="solid")
559
+ for cell in ws[1]:
560
+ cell.font = header_font
561
+ cell.fill = header_fill
562
+ cell.alignment = Alignment(horizontal="center", vertical="center")
563
+
564
+ level_map = {"high": "高度匹配", "possible": "可能匹配", "low_confidence": "低置信", "no_match": "不匹配"}
565
+ scope_map = {"current_task_target": "目标候选集", "history": "历史数据", "standard": "标准库"}
566
+
567
+ for r in results:
568
+ src = db.query(VectorDataRow).get(r.source_row_id)
569
+ tgt = db.query(VectorDataRow).get(r.target_row_id)
570
+ ws.append([
571
+ src.row_number if src else "",
572
+ src.raw_text if src else "",
573
+ r.rank,
574
+ tgt.raw_text if tgt else "",
575
+ round(r.similarity_score * 100, 2),
576
+ round(r.rerank_score, 4) if r.rerank_score is not None else "",
577
+ level_map.get(r.match_level, r.match_level),
578
+ scope_map.get(r.candidate_scope, r.candidate_scope or ""),
579
+ ])
580
+
581
+ # Column widths
582
+ col_widths = [8, 40, 10, 40, 12, 12, 12, 14]
583
+ for i, w in enumerate(col_widths, 1):
584
+ ws.column_dimensions[chr(64 + i)].width = w
585
+
586
+ output = io.BytesIO()
587
+ wb.save(output)
588
+ output.seek(0)
589
+
590
+ return StreamingResponse(
591
+ output,
592
+ media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
593
+ headers={"Content-Disposition": f"attachment; filename=match_result_{task.task_code}.xlsx"},
594
+ )
595
+
596
+
597
+ if __name__ == "__main__":
598
+ import uvicorn
599
+ uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
hf-vector-match-api/models.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from sqlalchemy import Column, Integer, String, Float, DateTime, Text, LargeBinary, ForeignKey
3
+ from sqlalchemy.orm import relationship
4
+ from database import Base
5
+
6
+ _TZ_BEIJING = datetime.timezone(datetime.timedelta(hours=8))
7
+
8
+ def _now_beijing():
9
+ return datetime.datetime.now(_TZ_BEIJING).replace(tzinfo=None)
10
+
11
+
12
+ class VectorMatchTask(Base):
13
+ __tablename__ = "vector_match_task"
14
+ __table_args__ = {"comment": "向量匹配任务表"}
15
+
16
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
17
+ task_code = Column(String(30), unique=True, nullable=False, index=True, comment="任务编号,格式:YYYYMMDDHHMMSSmmm")
18
+ match_mode = Column(String(50), nullable=False, default="two_file", comment="匹配模式:two_file/history/standard")
19
+ candidate_scope = Column(String(50), nullable=False, default="current_task_target", comment="候选范围:current_task_target/history/standard")
20
+ source_dataset_id = Column(Integer, ForeignKey("vector_dataset.id"), nullable=True, comment="源数据集ID")
21
+ target_dataset_id = Column(Integer, ForeignKey("vector_dataset.id"), nullable=True, comment="目标候选集ID")
22
+ top_k = Column(Integer, default=10, comment="每条源数据保留的Top-K候选数")
23
+ rerank_top_k = Column(Integer, default=3, comment="Reranker重排序后保留的Top-K数")
24
+ min_threshold = Column(Float, default=0.70, comment="最低相似度阈值")
25
+ status = Column(String(20), default="pending", comment="任务状态:pending/running/completed/failed")
26
+ source_row_count = Column(Integer, default=0, comment="源数据行数")
27
+ target_row_count = Column(Integer, default=0, comment="目标候选行数")
28
+ high_match_count = Column(Integer, default=0, comment="高度匹配数量(score>=0.90)")
29
+ low_confidence_count = Column(Integer, default=0, comment="低置信数量(score<0.70)")
30
+ reused_vectors = Column(Integer, default=0, comment="通过text_hash复用的向量数")
31
+ new_vectors = Column(Integer, default=0, comment="新生成的向量数")
32
+ progress_parse_source = Column(Integer, default=0, comment="解析源数据集进度(0-100)")
33
+ progress_parse_target = Column(Integer, default=0, comment="解析目标候选集进度(0-100)")
34
+ progress_vectorize = Column(Integer, default=0, comment="向量化进度(0-100)")
35
+ progress_load_candidates = Column(Integer, default=0, comment="加载候选范围进度(0-100)")
36
+ progress_similarity = Column(Integer, default=0, comment="相似度计算进度(0-100)")
37
+ progress_rerank = Column(Integer, default=0, comment="Reranker重排序进度(0-100)")
38
+ progress_save_results = Column(Integer, default=0, comment="保存结果进度(0-100)")
39
+ created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
40
+ updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")
41
+ is_archived = Column(Integer, default=0, comment="是否归档:0=未归档,1=已归档")
42
+ is_delete = Column(Integer, default=0, comment="是否删除:0=未删除,1=已删除")
43
+
44
+ source_dataset = relationship("VectorDataset", foreign_keys=[source_dataset_id])
45
+ target_dataset = relationship("VectorDataset", foreign_keys=[target_dataset_id])
46
+ results = relationship("MatchResult", back_populates="task")
47
+
48
+
49
+ class VectorDataset(Base):
50
+ __tablename__ = "vector_dataset"
51
+ __table_args__ = {"comment": "向量数据集表(上传或逻辑数据集)"}
52
+
53
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
54
+ task_id = Column(Integer, ForeignKey("vector_match_task.id"), nullable=True, comment="所属任务ID")
55
+ name = Column(String(255), nullable=False, comment="数据集名称")
56
+ file_name = Column(String(255), nullable=True, comment="上传文件名")
57
+ sheet_name = Column(String(100), nullable=True, comment="Excel工作表名")
58
+ dataset_role = Column(String(20), nullable=False, comment="数据集角色:source(源)/target(目标候选)")
59
+ data_scope = Column(String(20), default="task", comment="数据范围:task/history/standard")
60
+ vector_fields = Column(Text, nullable=True, comment="参与向量化的字段列表(JSON)")
61
+ row_count = Column(Integer, default=0, comment="数据行数")
62
+ is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
63
+ created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
64
+ updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")
65
+
66
+ rows = relationship("VectorDataRow", back_populates="dataset")
67
+
68
+
69
+ class VectorDataRow(Base):
70
+ __tablename__ = "vector_data_row"
71
+ __table_args__ = {"comment": "向量数据行表(单行物料/申报项等)"}
72
+
73
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
74
+ dataset_id = Column(Integer, ForeignKey("vector_dataset.id"), nullable=False, index=True, comment="所属数据集ID")
75
+ task_id = Column(Integer, nullable=True, index=True, comment="所属任务ID")
76
+ dataset_role = Column(String(20), nullable=False, index=True, comment="数据集角色:source/target")
77
+ data_scope = Column(String(20), default="task", index=True, comment="数据范围:task/history/standard")
78
+ row_number = Column(Integer, nullable=False, comment="Excel中的行号")
79
+ raw_text = Column(Text, nullable=False, comment="拼接后的原始文本")
80
+ text_hash = Column(String(64), nullable=True, index=True, comment="文本SHA256哈希,用于向量复用")
81
+ field_values = Column(Text, nullable=True, comment="各字段值(JSON)")
82
+ is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
83
+ created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
84
+ updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")
85
+
86
+ dataset = relationship("VectorDataset", back_populates="rows")
87
+ embedding = relationship("VectorEmbedding", back_populates="data_row", uselist=False)
88
+
89
+
90
+ class VectorEmbedding(Base):
91
+ __tablename__ = "vector_embedding"
92
+ __table_args__ = {"comment": "向量嵌入表(与数据行一对一)"}
93
+
94
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
95
+ data_row_id = Column(Integer, ForeignKey("vector_data_row.id"), unique=True, nullable=False, index=True, comment="关联 vector_data_row.id")
96
+ text_hash = Column(String(64), nullable=False, index=True, comment="与行一致的文本哈希")
97
+ embedding = Column(LargeBinary(length=65536), nullable=False, comment="float32 数组二进制存储")
98
+ model_name = Column(String(100), nullable=True, comment="生成向量所用模型名")
99
+ dimension = Column(Integer, nullable=True, comment="向量维度")
100
+ is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
101
+ created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
102
+ updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")
103
+
104
+ data_row = relationship("VectorDataRow", back_populates="embedding")
105
+
106
+
107
+ class MatchResult(Base):
108
+ __tablename__ = "match_result"
109
+ __table_args__ = {"comment": "匹配结果表(源行与候选的关联及得分)"}
110
+
111
+ id = Column(Integer, primary_key=True, autoincrement=True, comment="主键ID")
112
+ task_id = Column(Integer, ForeignKey("vector_match_task.id"), nullable=False, index=True, comment="所属任务ID")
113
+ source_row_id = Column(Integer, ForeignKey("vector_data_row.id"), nullable=False, comment="源数据行ID")
114
+ target_row_id = Column(Integer, ForeignKey("vector_data_row.id"), nullable=False, comment="目标候选行ID")
115
+ similarity_score = Column(Float, nullable=False, comment="余弦相似度分数(0-1)")
116
+ rerank_score = Column(Float, nullable=True, comment="Reranker精排分数,越高越相关")
117
+ rank = Column(Integer, nullable=False, comment="排名(1=最相似)")
118
+ rerank_rank = Column(Integer, nullable=True, comment="Reranker重排后的排名")
119
+ candidate_scope = Column(String(50), nullable=True, comment="候选来源范围")
120
+ match_level = Column(String(20), nullable=True, comment="匹配等级:high/possible/low_confidence/no_match")
121
+ is_confirmed = Column(Integer, default=0, comment="是否已人工确认:0=未确认,1=已确认,-1=已忽略")
122
+ is_delete = Column(Integer, default=0, nullable=False, index=True, comment="软删除标记:0=有效,1=已删除")
123
+ created_time = Column(DateTime, default=_now_beijing, comment="创建时间")
124
+ updated_time = Column(DateTime, default=_now_beijing, onupdate=_now_beijing, comment="更新时间")
125
+
126
+ task = relationship("VectorMatchTask", back_populates="results")
127
+ source_row = relationship("VectorDataRow", foreign_keys=[source_row_id])
128
+ target_row = relationship("VectorDataRow", foreign_keys=[target_row_id])
hf-vector-match-api/requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.0
2
+ uvicorn==0.30.6
3
+ sqlalchemy==2.0.35
4
+ python-multipart==0.0.12
5
+ openpyxl==3.1.5
6
+ numpy==1.24.4
7
+ pandas==2.0.3
8
+ httpx==0.26.0
9
+ pydantic==2.9.2
10
+ psycopg2-binary==2.9.9
11
+ pymysql==1.1.1
12
+ cryptography==43.0.1
13
+ python-dotenv==1.0.1
hf-vector-match-api/schemas.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from pydantic import BaseModel
3
+ from typing import Optional, List
4
+ from datetime import datetime
5
+
6
+
7
+ class DatasetInfo(BaseModel):
8
+ id: int
9
+ name: str
10
+ file_name: Optional[str] = None
11
+ sheet_name: Optional[str] = None
12
+ dataset_role: str
13
+ data_scope: str
14
+ vector_fields: Optional[str] = None
15
+ row_count: int = 0
16
+
17
+ class Config:
18
+ from_attributes = True
19
+
20
+
21
+ class TaskCreate(BaseModel):
22
+ match_mode: str = "two_file"
23
+ top_k: int = 3
24
+ min_threshold: float = 0.70
25
+ candidate_scope: str = "current_task_target"
26
+
27
+
28
+ class TaskProgress(BaseModel):
29
+ id: int
30
+ task_code: str
31
+ status: str
32
+ source_row_count: int
33
+ target_row_count: int
34
+ reused_vectors: int
35
+ new_vectors: int
36
+ progress_parse_source: int
37
+ progress_parse_target: int
38
+ progress_vectorize: int
39
+ progress_load_candidates: int
40
+ progress_similarity: int
41
+ progress_rerank: int = 0
42
+ progress_save_results: int
43
+
44
+ class Config:
45
+ from_attributes = True
46
+
47
+
48
+ class TaskDetail(BaseModel):
49
+ id: int
50
+ task_code: str
51
+ match_mode: str
52
+ candidate_scope: str
53
+ top_k: int
54
+ min_threshold: float
55
+ status: str
56
+ source_row_count: int
57
+ target_row_count: int
58
+ high_match_count: int
59
+ low_confidence_count: int
60
+ reused_vectors: int
61
+ new_vectors: int
62
+ source_dataset: Optional[DatasetInfo] = None
63
+ target_dataset: Optional[DatasetInfo] = None
64
+ created_time: Optional[datetime] = None
65
+ updated_time: Optional[datetime] = None
66
+
67
+ class Config:
68
+ from_attributes = True
69
+
70
+
71
+ class TaskListItem(BaseModel):
72
+ id: int
73
+ task_code: str
74
+ match_mode: str
75
+ candidate_scope: str
76
+ source_dataset_name: Optional[str] = None
77
+ target_dataset_name: Optional[str] = None
78
+ status: str
79
+ is_archived: int = 0
80
+ is_delete: int = 0
81
+ created_time: Optional[datetime] = None
82
+
83
+
84
+ class MatchResultItem(BaseModel):
85
+ id: int
86
+ source_row_id: int
87
+ source_row_number: int
88
+ source_text: str
89
+ target_text: str
90
+ similarity_score: float
91
+ rerank_score: Optional[float] = None
92
+ match_level: str
93
+ candidate_scope: Optional[str] = None
94
+ is_confirmed: int = 0
95
+
96
+
97
+ class MatchResultPage(BaseModel):
98
+ items: List[MatchResultItem]
99
+ total: int
100
+ page: int
101
+ page_size: int
102
+
103
+
104
+ class CandidateDetail(BaseModel):
105
+ rank: int
106
+ rerank_rank: Optional[int] = None
107
+ target_row_id: int
108
+ target_text: str
109
+ similarity_score: float
110
+ rerank_score: Optional[float] = None
111
+ match_level: str
112
+ dataset_role: str
113
+ candidate_scope: Optional[str] = None
114
+ data_row_id: int
115
+ is_confirmed: int = 0
116
+
117
+
118
+ class SourceWithCandidates(BaseModel):
119
+ source_row_id: int
120
+ source_text: str
121
+ source_row_number: int
122
+ dataset_role: str
123
+ data_row_id: int
124
+ candidates: List[CandidateDetail]
125
+
126
+
127
+ class SheetInfo(BaseModel):
128
+ sheet_names: List[str]
129
+ columns: dict
130
+
131
+
132
+ class UploadResponse(BaseModel):
133
+ dataset_id: int
134
+ file_name: str
135
+ sheet_names: List[str]
136
+ columns: dict
137
+ all_columns: dict = {}
138
+
139
+
140
+ class SettingItem(BaseModel):
141
+ key: str
142
+ value: str
143
+
144
+
145
+ class SettingsResponse(BaseModel):
146
+ settings: dict
hf-vector-match-api/services/__init__.py ADDED
File without changes
hf-vector-match-api/services/embedding_service.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import hashlib
4
+ import numpy as np
5
+ from typing import List, Optional, Dict
6
+ import httpx
7
+ from dotenv import load_dotenv
8
+
9
+ load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))
10
+
11
+ EMBEDDING_API_URL = os.environ.get("EMBEDDING_API_URL", "https://api.siliconflow.cn/v1/embeddings")
12
+ # EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "Qwen/Qwen3-VL-Embedding-8B")
13
+ # EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "4096"))
14
+ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-m3")
15
+ EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", "1024"))
16
+ EMBEDDING_PROVIDER = os.environ.get("EMBEDDING_PROVIDER", "siliconflow")
17
+ SILICONFLOW_API_KEY = os.environ.get("SILICONFLOW_API_KEY", "")
18
+
19
+ RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "Qwen/Qwen3-VL-Reranker-8B")
20
+ RERANKER_API_URL = os.environ.get("RERANKER_API_URL", "https://api.siliconflow.cn/v1/rerank")
21
+ RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() == "true"
22
+
23
+ # Bypass system proxy for SiliconFlow API calls
24
+ if "NO_PROXY" not in os.environ:
25
+ os.environ["NO_PROXY"] = "api.siliconflow.cn"
26
+ elif "siliconflow" not in os.environ.get("NO_PROXY", ""):
27
+ os.environ["NO_PROXY"] = os.environ["NO_PROXY"] + ",api.siliconflow.cn"
28
+
29
+
30
+ def _build_simple_embedding(text: str, dim: int = 768) -> np.ndarray:
31
+ """Fallback: deterministic pseudo-embedding based on character hashing.
32
+ Only for testing when no real embedding API is available."""
33
+ h = hashlib.sha512(text.encode("utf-8")).digest()
34
+ seed = int.from_bytes(h[:4], "big")
35
+ rng = np.random.RandomState(seed)
36
+ vec = rng.randn(dim).astype(np.float32)
37
+ norm = np.linalg.norm(vec)
38
+ if norm > 0:
39
+ vec = vec / norm
40
+ return vec
41
+
42
+
43
+ async def get_embeddings_batch(texts: List[str], model: Optional[str] = None) -> List[np.ndarray]:
44
+ """Generate embeddings for a batch of texts."""
45
+ model = model or EMBEDDING_MODEL
46
+ provider = EMBEDDING_PROVIDER.lower()
47
+
48
+ if provider == "siliconflow":
49
+ return await _siliconflow_embeddings(texts, model)
50
+ elif provider == "ollama":
51
+ return await _ollama_embeddings(texts, model)
52
+ elif provider == "openai":
53
+ return await _openai_embeddings(texts, model)
54
+ else:
55
+ return [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
56
+
57
+
58
+ async def _siliconflow_embeddings(texts: List[str], model: str) -> List[np.ndarray]:
59
+ """Call SiliconFlow (硅基流动) embedding API.
60
+ API docs: https://docs.siliconflow.cn/api-reference/embeddings
61
+ Compatible with OpenAI format, supports batch input."""
62
+ api_url = EMBEDDING_API_URL or "https://api.siliconflow.cn/v1/embeddings"
63
+ api_key = SILICONFLOW_API_KEY
64
+ if not api_key:
65
+ print("[WARN] SILICONFLOW_API_KEY not set, falling back to pseudo embeddings")
66
+ return [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
67
+
68
+ results = []
69
+ try:
70
+ async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
71
+ # SiliconFlow supports batch, but limit to 64 per request
72
+ for i in range(0, len(texts), 64):
73
+ batch = texts[i : i + 64]
74
+ resp = await client.post(
75
+ api_url,
76
+ headers={
77
+ "Authorization": f"Bearer {api_key}",
78
+ "Content-Type": "application/json",
79
+ },
80
+ json={"model": model, "input": batch, "encoding_format": "float"},
81
+ )
82
+ if resp.status_code == 200:
83
+ data = resp.json()
84
+ for item in sorted(data["data"], key=lambda x: x["index"]):
85
+ vec = np.array(item["embedding"], dtype=np.float32)
86
+ results.append(vec)
87
+ else:
88
+ print(f"[ERROR] SiliconFlow API returned {resp.status_code}: {resp.text[:200]}")
89
+ results.extend([_build_simple_embedding(t, EMBEDDING_DIM) for t in batch])
90
+ except Exception as e:
91
+ print(f"[ERROR] SiliconFlow API call failed: {e}")
92
+ results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
93
+ return results
94
+
95
+
96
+ async def _ollama_embeddings(texts: List[str], model: str) -> List[np.ndarray]:
97
+ """Call Ollama embedding API."""
98
+ results = []
99
+ try:
100
+ async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
101
+ for text in texts:
102
+ resp = await client.post(
103
+ EMBEDDING_API_URL,
104
+ json={"model": model, "input": text}
105
+ )
106
+ if resp.status_code == 200:
107
+ data = resp.json()
108
+ if "embeddings" in data:
109
+ vec = np.array(data["embeddings"][0], dtype=np.float32)
110
+ elif "embedding" in data:
111
+ vec = np.array(data["embedding"], dtype=np.float32)
112
+ else:
113
+ vec = _build_simple_embedding(text, EMBEDDING_DIM)
114
+ results.append(vec)
115
+ else:
116
+ results.append(_build_simple_embedding(text, EMBEDDING_DIM))
117
+ except Exception:
118
+ results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
119
+ return results
120
+
121
+
122
+ async def _openai_embeddings(texts: List[str], model: str) -> List[np.ndarray]:
123
+ """Call OpenAI-compatible embedding API (e.g., vLLM)."""
124
+ api_url = os.environ.get("OPENAI_API_BASE", "http://localhost:8000") + "/v1/embeddings"
125
+ api_key = os.environ.get("OPENAI_API_KEY", "no-key")
126
+ results = []
127
+ try:
128
+ async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
129
+ resp = await client.post(
130
+ api_url,
131
+ headers={"Authorization": f"Bearer {api_key}"},
132
+ json={"model": model, "input": texts}
133
+ )
134
+ if resp.status_code == 200:
135
+ data = resp.json()
136
+ for item in data["data"]:
137
+ vec = np.array(item["embedding"], dtype=np.float32)
138
+ results.append(vec)
139
+ else:
140
+ results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
141
+ except Exception:
142
+ results = [_build_simple_embedding(t, EMBEDDING_DIM) for t in texts]
143
+ return results
144
+
145
+
146
+ async def rerank_candidates(
147
+ query: str,
148
+ documents: List[str],
149
+ top_n: Optional[int] = None,
150
+ model: Optional[str] = None,
151
+ ) -> List[Dict]:
152
+ """Call SiliconFlow Reranker API (Qwen/Qwen3-VL-Reranker-8B).
153
+ Returns list of {"index": int, "relevance_score": float} sorted by score desc."""
154
+ model = model or RERANKER_MODEL
155
+ api_key = SILICONFLOW_API_KEY
156
+
157
+ if not api_key or not RERANKER_ENABLED:
158
+ return [{"index": i, "relevance_score": 0.0} for i in range(len(documents))]
159
+
160
+ if not documents:
161
+ return []
162
+
163
+ top_n = top_n or len(documents)
164
+
165
+ try:
166
+ async with httpx.AsyncClient(timeout=120.0, proxies={}) as client:
167
+ resp = await client.post(
168
+ RERANKER_API_URL,
169
+ headers={
170
+ "Authorization": f"Bearer {api_key}",
171
+ "Content-Type": "application/json",
172
+ },
173
+ json={
174
+ "model": model,
175
+ "query": query,
176
+ "documents": documents,
177
+ "top_n": top_n,
178
+ "return_documents": False,
179
+ },
180
+ )
181
+ if resp.status_code == 200:
182
+ data = resp.json()
183
+ results = data.get("results", [])
184
+ return sorted(results, key=lambda x: x["relevance_score"], reverse=True)
185
+ else:
186
+ print(f"[ERROR] Reranker API returned {resp.status_code}: {resp.text[:200]}")
187
+ return [{"index": i, "relevance_score": 0.0} for i in range(len(documents))]
188
+ except Exception as e:
189
+ print(f"[ERROR] Reranker API call failed: {e}")
190
+ return [{"index": i, "relevance_score": 0.0} for i in range(len(documents))]
191
+
192
+
193
+ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
194
+ """Compute cosine similarity between two vectors."""
195
+ norm_a = np.linalg.norm(a)
196
+ norm_b = np.linalg.norm(b)
197
+ if norm_a == 0 or norm_b == 0:
198
+ return 0.0
199
+ return float(np.dot(a, b) / (norm_a * norm_b))
200
+
201
+
202
+ def batch_cosine_similarity(source_vecs: np.ndarray, target_vecs: np.ndarray) -> np.ndarray:
203
+ """Compute pairwise cosine similarity matrix.
204
+ source_vecs: (M, D), target_vecs: (N, D)
205
+ Returns: (M, N) similarity matrix"""
206
+ source_norms = np.linalg.norm(source_vecs, axis=1, keepdims=True)
207
+ target_norms = np.linalg.norm(target_vecs, axis=1, keepdims=True)
208
+ source_norms = np.where(source_norms == 0, 1, source_norms)
209
+ target_norms = np.where(target_norms == 0, 1, target_norms)
210
+ source_normed = source_vecs / source_norms
211
+ target_normed = target_vecs / target_norms
212
+ return source_normed @ target_normed.T
213
+
214
+
215
+ def embedding_to_bytes(vec: np.ndarray) -> bytes:
216
+ return vec.astype(np.float32).tobytes()
217
+
218
+
219
+ def bytes_to_embedding(data: bytes) -> np.ndarray:
220
+ return np.frombuffer(data, dtype=np.float32)
221
+
222
+
223
+ def get_match_level(score: float) -> str:
224
+ if score >= 0.90:
225
+ return "high"
226
+ elif score >= 0.80:
227
+ return "possible"
228
+ elif score >= 0.70:
229
+ return "low_confidence"
230
+ else:
231
+ return "no_match"
hf-vector-match-api/services/excel_service.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import hashlib
3
+ import json
4
+ import openpyxl
5
+ from typing import List, Dict, Tuple
6
+
7
+
8
+ UPLOAD_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data", "uploads")
9
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
10
+
11
+
12
+ def save_upload_file(file_bytes: bytes, filename: str) -> str:
13
+ filepath = os.path.join(UPLOAD_DIR, filename)
14
+ with open(filepath, "wb") as f:
15
+ f.write(file_bytes)
16
+ return filepath
17
+
18
+
19
+ EXCLUDED_FIELDS = {"序号", "行号", "编号", "id", "ID", "Id", "no", "No", "NO", "行", "#"}
20
+
21
+ def get_sheet_info(filepath: str) -> Dict:
22
+ wb = openpyxl.load_workbook(filepath, read_only=True)
23
+ result = {"sheet_names": wb.sheetnames, "columns": {}, "all_columns": {}}
24
+ for sheet_name in wb.sheetnames:
25
+ ws = wb[sheet_name]
26
+ headers = []
27
+ for row in ws.iter_rows(min_row=1, max_row=1, values_only=True):
28
+ headers = [str(c) if c else f"列{i+1}" for i, c in enumerate(row)]
29
+ result["all_columns"][sheet_name] = headers
30
+ result["columns"][sheet_name] = [
31
+ h for h in headers if h.strip() not in EXCLUDED_FIELDS
32
+ ]
33
+ wb.close()
34
+ return result
35
+
36
+
37
+ def parse_excel_rows(
38
+ filepath: str,
39
+ sheet_name: str,
40
+ vector_fields: List[str],
41
+ ) -> List[Dict]:
42
+ wb = openpyxl.load_workbook(filepath, read_only=True)
43
+ ws = wb[sheet_name]
44
+ rows_data = []
45
+ headers = []
46
+ for row_idx, row in enumerate(ws.iter_rows(values_only=True)):
47
+ if row_idx == 0:
48
+ headers = [str(c) if c else f"列{i+1}" for i, c in enumerate(row)]
49
+ continue
50
+ row_dict = {}
51
+ for i, val in enumerate(row):
52
+ if i < len(headers):
53
+ row_dict[headers[i]] = str(val) if val is not None else ""
54
+
55
+ text_parts = []
56
+ for field in vector_fields:
57
+ if field in row_dict and row_dict[field]:
58
+ text_parts.append(row_dict[field])
59
+ raw_text = " ".join(text_parts)
60
+
61
+ if not raw_text.strip():
62
+ continue
63
+
64
+ text_hash = hashlib.sha256(raw_text.encode("utf-8")).hexdigest()
65
+
66
+ rows_data.append({
67
+ "row_number": row_idx + 1,
68
+ "raw_text": raw_text,
69
+ "text_hash": text_hash,
70
+ "field_values": json.dumps(row_dict, ensure_ascii=False),
71
+ })
72
+ wb.close()
73
+ return rows_data
hf-vector-match-api/services/match_service.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import numpy as np
3
+ from typing import List, Dict, Optional
4
+ from sqlalchemy.orm import Session
5
+
6
+ from models import (
7
+ VectorMatchTask, VectorDataset, VectorDataRow,
8
+ VectorEmbedding, MatchResult
9
+ )
10
+ from services.embedding_service import (
11
+ get_embeddings_batch, batch_cosine_similarity,
12
+ embedding_to_bytes, bytes_to_embedding, get_match_level,
13
+ rerank_candidates, RERANKER_ENABLED
14
+ )
15
+
16
+
17
+ BATCH_SIZE = 32
18
+
19
+
20
+ def _safe_commit(db):
21
+ """提交事务,连接断开时自动回滚并重试"""
22
+ try:
23
+ db.commit()
24
+ except Exception:
25
+ db.rollback()
26
+ try:
27
+ db.commit()
28
+ except Exception:
29
+ db.rollback()
30
+
31
+
32
+ async def run_match_task(task_id: int, db_factory):
33
+ """Main matching pipeline: parse → vectorize → match → save results."""
34
+ db: Session = db_factory()
35
+ try:
36
+ task = db.query(VectorMatchTask).get(task_id)
37
+ if not task:
38
+ return
39
+ task.status = "running"
40
+ _safe_commit(db)
41
+
42
+ # Step 1: Parse source
43
+ task.progress_parse_source = 100
44
+ _safe_commit(db)
45
+
46
+ # Step 2: Parse target
47
+ task.progress_parse_target = 100
48
+ _safe_commit(db)
49
+
50
+ # Step 3: Vectorize
51
+ source_rows = (
52
+ db.query(VectorDataRow)
53
+ .filter(VectorDataRow.dataset_id == task.source_dataset_id)
54
+ .all()
55
+ )
56
+ target_rows = (
57
+ db.query(VectorDataRow)
58
+ .filter(VectorDataRow.dataset_id == task.target_dataset_id)
59
+ .all()
60
+ )
61
+
62
+ task.source_row_count = len(source_rows)
63
+ task.target_row_count = len(target_rows)
64
+ _safe_commit(db)
65
+
66
+ all_rows = source_rows + target_rows
67
+ reused = 0
68
+ new_count = 0
69
+
70
+ for i in range(0, len(all_rows), BATCH_SIZE):
71
+ batch = all_rows[i : i + BATCH_SIZE]
72
+ texts_to_embed = []
73
+ rows_to_embed = []
74
+
75
+ for row in batch:
76
+ existing = (
77
+ db.query(VectorEmbedding)
78
+ .filter(VectorEmbedding.text_hash == row.text_hash)
79
+ .first()
80
+ )
81
+ if existing and existing.data_row_id != row.id:
82
+ new_emb = VectorEmbedding(
83
+ data_row_id=row.id,
84
+ text_hash=row.text_hash,
85
+ embedding=existing.embedding,
86
+ model_name=existing.model_name,
87
+ dimension=existing.dimension,
88
+ )
89
+ db.add(new_emb)
90
+ reused += 1
91
+ elif existing:
92
+ reused += 1
93
+ else:
94
+ texts_to_embed.append(row.raw_text)
95
+ rows_to_embed.append(row)
96
+
97
+ if texts_to_embed:
98
+ embeddings = await get_embeddings_batch(texts_to_embed)
99
+ for row, vec in zip(rows_to_embed, embeddings):
100
+ emb = VectorEmbedding(
101
+ data_row_id=row.id,
102
+ text_hash=row.text_hash,
103
+ embedding=embedding_to_bytes(vec),
104
+ model_name="default",
105
+ dimension=len(vec),
106
+ )
107
+ db.add(emb)
108
+ new_count += 1
109
+
110
+ progress = min(100, int((i + len(batch)) / max(len(all_rows), 1) * 100))
111
+ task.progress_vectorize = progress
112
+ task.reused_vectors = reused
113
+ task.new_vectors = new_count
114
+ _safe_commit(db)
115
+
116
+ task.progress_vectorize = 100
117
+ _safe_commit(db)
118
+
119
+ # Step 4: Load candidate range
120
+ task.progress_load_candidates = 100
121
+ _safe_commit(db)
122
+
123
+ # Step 5: Similarity calculation
124
+ source_embeddings = []
125
+ source_row_ids = []
126
+ for row in source_rows:
127
+ emb = db.query(VectorEmbedding).filter(VectorEmbedding.data_row_id == row.id).first()
128
+ if emb:
129
+ source_embeddings.append(bytes_to_embedding(emb.embedding))
130
+ source_row_ids.append(row.id)
131
+
132
+ target_embeddings = []
133
+ target_row_ids = []
134
+ for row in target_rows:
135
+ emb = db.query(VectorEmbedding).filter(VectorEmbedding.data_row_id == row.id).first()
136
+ if emb:
137
+ target_embeddings.append(bytes_to_embedding(emb.embedding))
138
+ target_row_ids.append(row.id)
139
+
140
+ if not source_embeddings or not target_embeddings:
141
+ task.status = "completed"
142
+ task.progress_similarity = 100
143
+ task.progress_save_results = 100
144
+ _safe_commit(db)
145
+ return
146
+
147
+ source_matrix = np.stack(source_embeddings)
148
+ target_matrix = np.stack(target_embeddings)
149
+
150
+ sim_matrix = batch_cosine_similarity(source_matrix, target_matrix)
151
+ task.progress_similarity = 100
152
+ _safe_commit(db)
153
+
154
+ # Step 6: Collect Top-K candidates per source row
155
+ # top_k 为初始候选数,rerank_top_k 为重排序后保留数
156
+ initial_k = task.top_k
157
+ initial_k = min(initial_k, len(target_row_ids))
158
+
159
+ # Build raw_text lookup for reranker
160
+ source_text_map = {}
161
+ target_text_map = {}
162
+ if RERANKER_ENABLED:
163
+ for row in source_rows:
164
+ source_text_map[row.id] = row.raw_text
165
+ for row in target_rows:
166
+ target_text_map[row.id] = row.raw_text
167
+
168
+ high_count = 0
169
+ low_count = 0
170
+ total_source = len(source_row_ids)
171
+
172
+ for idx, src_id in enumerate(source_row_ids):
173
+ scores = sim_matrix[idx]
174
+ top_indices = np.argsort(scores)[::-1][:initial_k]
175
+
176
+ candidates = []
177
+ for tgt_idx in top_indices:
178
+ candidates.append({
179
+ "tgt_idx": tgt_idx,
180
+ "tgt_row_id": target_row_ids[tgt_idx],
181
+ "sim_score": float(scores[tgt_idx]),
182
+ "rerank_score": None,
183
+ })
184
+
185
+ # Step 6.5: Rerank candidates
186
+ if RERANKER_ENABLED and candidates:
187
+ query_text = source_text_map.get(src_id, "")
188
+ doc_texts = [target_text_map.get(c["tgt_row_id"], "") for c in candidates]
189
+ try:
190
+ rerank_top_k = task.rerank_top_k or task.top_k
191
+ rerank_results = await rerank_candidates(
192
+ query=query_text,
193
+ documents=doc_texts,
194
+ top_n=rerank_top_k,
195
+ )
196
+ # Map rerank scores back to candidates
197
+ for rr in rerank_results:
198
+ orig_idx = rr["index"]
199
+ if orig_idx < len(candidates):
200
+ candidates[orig_idx]["rerank_score"] = rr["relevance_score"]
201
+
202
+ # Sort by rerank_score (desc), keep rerank_top_k
203
+ candidates.sort(
204
+ key=lambda c: c["rerank_score"] if c["rerank_score"] is not None else -1,
205
+ reverse=True,
206
+ )
207
+ candidates = candidates[:rerank_top_k]
208
+ except Exception as e:
209
+ print(f"[WARN] Rerank failed for source {src_id}: {e}")
210
+ candidates = candidates[:task.top_k]
211
+
212
+ progress = min(100, int((idx + 1) / total_source * 100))
213
+ task.progress_rerank = progress
214
+ if idx % 20 == 0:
215
+ _safe_commit(db)
216
+ else:
217
+ candidates = candidates[:task.top_k]
218
+
219
+ # Save results
220
+ for rank, c in enumerate(candidates):
221
+ level = get_match_level(c["sim_score"])
222
+ result = MatchResult(
223
+ task_id=task.id,
224
+ source_row_id=src_id,
225
+ target_row_id=c["tgt_row_id"],
226
+ similarity_score=c["sim_score"],
227
+ rerank_score=c["rerank_score"],
228
+ rank=rank + 1,
229
+ rerank_rank=rank + 1 if c["rerank_score"] is not None else None,
230
+ candidate_scope=task.candidate_scope,
231
+ match_level=level,
232
+ )
233
+ db.add(result)
234
+
235
+ if rank == 0:
236
+ if c["sim_score"] >= 0.90:
237
+ high_count += 1
238
+ elif c["sim_score"] < 0.70:
239
+ low_count += 1
240
+
241
+ progress = min(100, int((idx + 1) / total_source * 100))
242
+ task.progress_save_results = progress
243
+ if idx % 50 == 0:
244
+ _safe_commit(db)
245
+
246
+ task.high_match_count = high_count
247
+ task.low_confidence_count = low_count
248
+ task.progress_rerank = 100
249
+ task.progress_save_results = 100
250
+ task.status = "completed"
251
+ _safe_commit(db)
252
+
253
+ except Exception as e:
254
+ task = db.query(VectorMatchTask).get(task_id)
255
+ if task:
256
+ task.status = "failed"
257
+ _safe_commit(db)
258
+ raise e
259
+ finally:
260
+ db.close()