# vector-match-api/services/match_service.py
# Uploaded by teryryy ("Upload 13 files", commit ba016aa).
import asyncio
import logging
from typing import List, Dict, Optional

import numpy as np
from sqlalchemy.orm import Session

from models import (
    VectorMatchTask, VectorDataset, VectorDataRow,
    VectorEmbedding, MatchResult
)
from services.embedding_service import (
    get_embeddings_batch, batch_cosine_similarity,
    embedding_to_bytes, bytes_to_embedding, get_match_level,
    rerank_candidates, RERANKER_ENABLED
)
# Number of data rows processed per vectorization chunk: the not-yet-embedded
# texts of each chunk are sent to get_embeddings_batch in a single call, and
# task progress is committed once per chunk.
BATCH_SIZE = 32
def _safe_commit(db):
"""提交事务,连接断开时自动回滚并重试"""
try:
db.commit()
except Exception:
db.rollback()
try:
db.commit()
except Exception:
db.rollback()
def _load_row_embeddings(db: "Session", rows) -> tuple:
    """Decode the stored embedding of each row in *rows*.

    Returns ``(vectors, row_ids)`` as parallel lists. Rows with no
    VectorEmbedding are skipped, so ``vectors[i]`` always belongs to
    ``row_ids[i]``.
    """
    vectors = []
    row_ids = []
    for row in rows:
        emb = (
            db.query(VectorEmbedding)
            .filter(VectorEmbedding.data_row_id == row.id)
            .first()
        )
        if emb:
            vectors.append(bytes_to_embedding(emb.embedding))
            row_ids.append(row.id)
    return vectors, row_ids


async def _vectorize_rows(db: "Session", task, rows) -> None:
    """Ensure each row in *rows* has a VectorEmbedding for its current text.

    Rows are handled in chunks of BATCH_SIZE. For each row:
      * it already has an embedding with its own ``text_hash``: counted as
        reused, nothing is written;
      * another embedding with the same ``text_hash`` exists: its vector is
        copied into a new VectorEmbedding for this row (reused);
      * otherwise the row's ``raw_text`` is embedded via get_embeddings_batch
        (new).
    After each chunk, ``task.progress_vectorize``, ``task.reused_vectors`` and
    ``task.new_vectors`` are updated and committed.
    """
    reused = 0
    new_count = 0
    total = len(rows)
    for start in range(0, total, BATCH_SIZE):
        batch = rows[start : start + BATCH_SIZE]
        rows_to_embed = []
        for row in batch:
            # Check the row's own embedding first. A lookup by text_hash alone
            # may return another row's embedding and then store a duplicate
            # copy for a row that is already vectorized (e.g. on a re-run).
            own = (
                db.query(VectorEmbedding)
                .filter(
                    VectorEmbedding.data_row_id == row.id,
                    VectorEmbedding.text_hash == row.text_hash,
                )
                .first()
            )
            if own is not None:
                reused += 1
                continue
            existing = (
                db.query(VectorEmbedding)
                .filter(VectorEmbedding.text_hash == row.text_hash)
                .first()
            )
            if existing is not None:
                # Same text already embedded for another row: copy the vector.
                db.add(
                    VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=existing.embedding,
                        model_name=existing.model_name,
                        dimension=existing.dimension,
                    )
                )
                reused += 1
            else:
                rows_to_embed.append(row)
        if rows_to_embed:
            embeddings = await get_embeddings_batch(
                [row.raw_text for row in rows_to_embed]
            )
            for row, vec in zip(rows_to_embed, embeddings):
                db.add(
                    VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=embedding_to_bytes(vec),
                        model_name="default",
                        dimension=len(vec),
                    )
                )
                new_count += 1
        task.progress_vectorize = min(
            100, int((start + len(batch)) / max(total, 1) * 100)
        )
        task.reused_vectors = reused
        task.new_vectors = new_count
        _safe_commit(db)


def _apply_rerank_scores(candidates, rerank_results, keep):
    """Attach reranker scores to *candidates* and return the best *keep*.

    Each item of *rerank_results* is read as ``rr["index"]`` (a position in
    *candidates*) and ``rr["relevance_score"]``. Out-of-range indices
    (including negative ones) are ignored. Candidates the reranker did not
    score keep ``rerank_score=None`` and sort after every scored candidate,
    whatever the scores' sign. *candidates* itself is not modified, so a
    failure here leaves the caller's list intact.
    """
    scores = {}
    for rr in rerank_results:
        orig_idx = rr["index"]
        if 0 <= orig_idx < len(candidates):
            scores[orig_idx] = rr["relevance_score"]
    rescored = [
        {**c, "rerank_score": scores.get(i, c["rerank_score"])}
        for i, c in enumerate(candidates)
    ]
    # -inf (not -1) so unscored candidates never outrank negative scores.
    rescored.sort(
        key=lambda c: c["rerank_score"]
        if c["rerank_score"] is not None
        else float("-inf"),
        reverse=True,
    )
    return rescored[:keep]


def _mark_task_failed(db: "Session", task_id: int) -> None:
    """Best-effort: set the task's status to "failed" after a pipeline error.

    The session is rolled back first. If the error came from the database
    (e.g. a failed flush), the session cannot commit until it is rolled back,
    and committing the status change would fail and discard it. Rolling back
    also drops anything still pending from the failed run. Errors here are
    logged, not raised, so they never mask the original exception.
    """
    logger = logging.getLogger(__name__)
    try:
        db.rollback()
        task = db.query(VectorMatchTask).get(task_id)
        if task:
            task.status = "failed"
            _safe_commit(db)
    except Exception:
        logger.exception("Could not mark match task %s as failed", task_id)


async def run_match_task(task_id: int, db_factory):
    """Main matching pipeline: parse → vectorize → match → save results.

    Opens a session from *db_factory* and loads VectorMatchTask *task_id*; if
    the task does not exist, it returns without doing anything. Otherwise it
    runs the task to completion, writing each stage's progress field and a
    final status of "completed".

    Stages:
      1-2. Source/target parsing: only the progress fields are set here.
      3.   Vectorize every source and target row (see _vectorize_rows).
      4.   Candidate loading: progress field only.
      5.   Similarity of all source vectors against all target vectors, via
           batch_cosine_similarity.
      6.   For each source row, keep the ``task.top_k`` most similar targets.
           If RERANKER_ENABLED, rerank them and keep ``task.rerank_top_k``
           (falling back to ``task.top_k``). Persist the result as
           MatchResult rows.

    Raises:
        Re-raises any exception from the pipeline after marking the task
        "failed". The session is always closed.
    """
    logger = logging.getLogger(__name__)
    db: Session = db_factory()
    try:
        task = db.query(VectorMatchTask).get(task_id)
        if not task:
            return
        task.status = "running"
        _safe_commit(db)

        # Steps 1-2: parsing happens elsewhere; only progress is reported here.
        task.progress_parse_source = 100
        _safe_commit(db)
        task.progress_parse_target = 100
        _safe_commit(db)

        # Step 3: Vectorize
        source_rows = (
            db.query(VectorDataRow)
            .filter(VectorDataRow.dataset_id == task.source_dataset_id)
            .all()
        )
        target_rows = (
            db.query(VectorDataRow)
            .filter(VectorDataRow.dataset_id == task.target_dataset_id)
            .all()
        )
        task.source_row_count = len(source_rows)
        task.target_row_count = len(target_rows)
        _safe_commit(db)

        await _vectorize_rows(db, task, source_rows + target_rows)
        task.progress_vectorize = 100
        _safe_commit(db)

        # Step 4: Load candidate range
        task.progress_load_candidates = 100
        _safe_commit(db)

        # Step 5: Similarity calculation
        source_embeddings, source_row_ids = _load_row_embeddings(db, source_rows)
        target_embeddings, target_row_ids = _load_row_embeddings(db, target_rows)
        if not source_embeddings or not target_embeddings:
            task.status = "completed"
            task.progress_similarity = 100
            task.progress_save_results = 100
            _safe_commit(db)
            return
        sim_matrix = batch_cosine_similarity(
            np.stack(source_embeddings), np.stack(target_embeddings)
        )
        task.progress_similarity = 100
        _safe_commit(db)

        # Step 6: Collect Top-K candidates per source row.
        # top_k is the initial candidate count; rerank_top_k is how many are
        # kept after reranking.
        initial_k = min(task.top_k, len(target_row_ids))
        if RERANKER_ENABLED:
            # Raw texts are only needed as reranker input.
            source_text_map = {row.id: row.raw_text for row in source_rows}
            target_text_map = {row.id: row.raw_text for row in target_rows}
        else:
            source_text_map = {}
            target_text_map = {}

        high_count = 0
        low_count = 0
        total_source = len(source_row_ids)
        for idx, src_id in enumerate(source_row_ids):
            scores = sim_matrix[idx]
            top_indices = np.argsort(scores)[::-1][:initial_k]
            candidates = [
                {
                    "tgt_idx": tgt_idx,
                    "tgt_row_id": target_row_ids[tgt_idx],
                    "sim_score": float(scores[tgt_idx]),
                    "rerank_score": None,
                }
                for tgt_idx in top_indices
            ]
            progress = min(100, int((idx + 1) / total_source * 100))

            # Step 6.5: Rerank candidates
            if RERANKER_ENABLED and candidates:
                rerank_top_k = task.rerank_top_k or task.top_k
                try:
                    rerank_results = await rerank_candidates(
                        query=source_text_map.get(src_id, ""),
                        documents=[
                            target_text_map.get(c["tgt_row_id"], "")
                            for c in candidates
                        ],
                        top_n=rerank_top_k,
                    )
                    candidates = _apply_rerank_scores(
                        candidates, rerank_results, rerank_top_k
                    )
                except Exception as exc:
                    # Degrade to similarity-only ranking for this source row.
                    logger.warning("Rerank failed for source %s: %s", src_id, exc)
                    candidates = candidates[: task.top_k]
                task.progress_rerank = progress
                if idx % 20 == 0:
                    _safe_commit(db)
            else:
                candidates = candidates[: task.top_k]

            # Save results
            for rank, c in enumerate(candidates, start=1):
                db.add(
                    MatchResult(
                        task_id=task.id,
                        source_row_id=src_id,
                        target_row_id=c["tgt_row_id"],
                        similarity_score=c["sim_score"],
                        rerank_score=c["rerank_score"],
                        rank=rank,
                        rerank_rank=rank if c["rerank_score"] is not None else None,
                        candidate_scope=task.candidate_scope,
                        match_level=get_match_level(c["sim_score"]),
                    )
                )
            if candidates:
                # Confidence buckets are based on the best-ranked candidate's
                # cosine similarity.
                best_sim = candidates[0]["sim_score"]
                if best_sim >= 0.90:
                    high_count += 1
                elif best_sim < 0.70:
                    low_count += 1
            task.progress_save_results = progress
            if idx % 50 == 0:
                _safe_commit(db)

        task.high_match_count = high_count
        task.low_confidence_count = low_count
        task.progress_rerank = 100
        task.progress_save_results = 100
        task.status = "completed"
        _safe_commit(db)
    except Exception:
        _mark_task_failed(db, task_id)
        raise
    finally:
        db.close()