import asyncio
from typing import Dict, List, Optional

import numpy as np
from sqlalchemy.orm import Session

from models import (
    VectorMatchTask,
    VectorDataset,
    VectorDataRow,
    VectorEmbedding,
    MatchResult,
)
from services.embedding_service import (
    get_embeddings_batch,
    batch_cosine_similarity,
    embedding_to_bytes,
    bytes_to_embedding,
    get_match_level,
    rerank_candidates,
    RERANKER_ENABLED,
)

# Rows per embedding request; progress is also committed once per batch.
BATCH_SIZE = 32

# Max ids per SQL ``IN (...)`` clause when bulk-loading embeddings; kept well
# below common bound-parameter limits (e.g. SQLite's).
_IN_CLAUSE_CHUNK = 500

# Top-1 similarity thresholds feeding the task's summary counters.
HIGH_MATCH_THRESHOLD = 0.90
LOW_CONFIDENCE_THRESHOLD = 0.70

# Commit progress every N source rows during rerank / result saving.
_RERANK_COMMIT_EVERY = 20
_SAVE_COMMIT_EVERY = 50


def _safe_commit(db: Session) -> bool:
    """Commit ``db``; on failure roll back and attempt one more commit.

    ``Session.rollback()`` discards the failed transaction's pending changes.
    The retry therefore only leaves the session usable (e.g. after a dropped
    connection). It does NOT re-apply the lost writes. Failures are printed
    rather than silently swallowed, so lost updates are visible.

    Returns:
        True if the first commit succeeded, False if the changes were rolled back.

    Raises:
        Whatever ``db.rollback()`` raises; commit errors are not propagated.
    """
    try:
        db.commit()
        return True
    except Exception as exc:
        print(f"[WARN] Commit failed, changes rolled back: {exc}")
        db.rollback()
    try:
        db.commit()
    except Exception as exc:
        print(f"[WARN] Retry commit after rollback failed: {exc}")
        db.rollback()
    return False


def _dataset_rows(db: Session, dataset_id) -> list:
    """Return every VectorDataRow whose ``dataset_id`` equals ``dataset_id``."""
    return (
        db.query(VectorDataRow)
        .filter(VectorDataRow.dataset_id == dataset_id)
        .all()
    )


async def _vectorize_rows(db: Session, task, all_rows: list) -> None:
    """Make sure every row in ``all_rows`` has a VectorEmbedding for its text.

    For each row, in priority order:
    * an embedding for (row.id, row.text_hash) already exists -> counted as reused;
    * another row's embedding with the same ``text_hash`` exists -> its vector
      is copied to a new VectorEmbedding for this row (reused);
    * otherwise the text is embedded. Identical texts within a batch are sent
      to the embedding service once, and the extra rows count as reused.

    Updates ``task.progress_vectorize``, ``task.reused_vectors`` and
    ``task.new_vectors`` and commits after each batch.

    Raises:
        ValueError: if the embedding service returns a different number of
            vectors than texts (pairing them would silently misalign rows).
    """
    reused = 0
    new_count = 0
    total = len(all_rows)
    for start in range(0, total, BATCH_SIZE):
        batch = all_rows[start:start + BATCH_SIZE]

        # One query per batch instead of per row. Keyed by (row id, hash), so a
        # row whose text changed since it was embedded is embedded again.
        batch_ids = [row.id for row in batch]
        embedded = {
            (data_row_id, text_hash)
            for data_row_id, text_hash in db.query(
                VectorEmbedding.data_row_id, VectorEmbedding.text_hash
            ).filter(VectorEmbedding.data_row_id.in_(batch_ids))
        }

        to_embed: Dict[str, List] = {}  # raw_text -> rows that need its vector
        for row in batch:
            key = (row.id, row.text_hash)
            if key in embedded:
                reused += 1
                continue
            # Mark as handled so a row repeated in the same batch is not
            # embedded (or copied) twice.
            embedded.add(key)

            donor = (
                db.query(VectorEmbedding)
                .filter(VectorEmbedding.text_hash == row.text_hash)
                .first()
            )
            if donor is not None:
                db.add(
                    VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=donor.embedding,
                        model_name=donor.model_name,
                        dimension=donor.dimension,
                    )
                )
                reused += 1
            else:
                to_embed.setdefault(row.raw_text, []).append(row)

        if to_embed:
            texts = list(to_embed)
            vectors = list(await get_embeddings_batch(texts))
            if len(vectors) != len(texts):
                raise ValueError(
                    f"Embedding service returned {len(vectors)} vectors "
                    f"for {len(texts)} texts"
                )
            for text, vec in zip(texts, vectors):
                blob = embedding_to_bytes(vec)
                rows_for_text = to_embed[text]
                for row in rows_for_text:
                    db.add(
                        VectorEmbedding(
                            data_row_id=row.id,
                            text_hash=row.text_hash,
                            embedding=blob,
                            model_name="default",
                            dimension=len(vec),
                        )
                    )
                new_count += 1
                reused += len(rows_for_text) - 1

        task.progress_vectorize = min(100, int((start + len(batch)) / max(total, 1) * 100))
        task.reused_vectors = reused
        task.new_vectors = new_count
        _safe_commit(db)


def _load_embeddings(db: Session, rows: list):
    """Load decoded embeddings for ``rows`` using chunked ``IN`` queries.

    Returns ``(vectors, row_ids)``, two parallel lists in the order of
    ``rows``. Rows without any stored embedding are skipped. If a row has
    several embeddings, the one whose ``text_hash`` matches the row's current
    ``text_hash`` is preferred; otherwise the first returned is used.
    """
    stored_by_row: Dict[int, list] = {}
    ids = [row.id for row in rows]
    for start in range(0, len(ids), _IN_CLAUSE_CHUNK):
        chunk = ids[start:start + _IN_CLAUSE_CHUNK]
        query = db.query(
            VectorEmbedding.data_row_id,
            VectorEmbedding.text_hash,
            VectorEmbedding.embedding,
        ).filter(VectorEmbedding.data_row_id.in_(chunk))
        for data_row_id, text_hash, blob in query:
            stored_by_row.setdefault(data_row_id, []).append((text_hash, blob))

    vectors = []
    row_ids = []
    for row in rows:
        stored = stored_by_row.get(row.id)
        if not stored:
            continue
        blob = next((b for h, b in stored if h == row.text_hash), stored[0][1])
        vectors.append(bytes_to_embedding(blob))
        row_ids.append(row.id)
    return vectors, row_ids


async def _rerank_for_source(src_id, query_text, doc_texts, candidates, rerank_top_k, fallback_k):
    """Reorder one source row's candidates by reranker relevance.

    Returns at most ``rerank_top_k`` candidate dicts sorted by
    ``rerank_score`` in descending order. Candidates the reranker did not
    score go last. If the reranker call or its result handling fails, a
    warning is printed and the untouched, similarity-ordered
    ``candidates[:fallback_k]`` is returned instead.
    """
    try:
        results = await rerank_candidates(
            query=query_text,
            documents=doc_texts,
            top_n=rerank_top_k,
        )
        # Work on copies so a failure part-way leaves ``candidates`` untouched.
        scored = [dict(c) for c in candidates]
        for rr in results:
            pos = rr["index"]
            # Negative indices would wrap around to the wrong candidate.
            if 0 <= pos < len(scored):
                scored[pos]["rerank_score"] = rr["relevance_score"]
        # -inf (not -1) keeps unscored candidates last even when the reranker
        # returns negative relevance scores.
        scored.sort(
            key=lambda c: c["rerank_score"] if c["rerank_score"] is not None else float("-inf"),
            reverse=True,
        )
        return scored[:rerank_top_k]
    except Exception as e:
        print(f"[WARN] Rerank failed for source {src_id}: {e}")
        return candidates[:fallback_k]


async def _rank_and_save_matches(
    db: Session, task, sim_matrix, source_row_ids, target_row_ids, source_rows, target_rows
) -> None:
    """Pick Top-K targets per source row, optionally rerank, and add MatchResults.

    Row ``i`` of ``sim_matrix`` is expected to hold the similarities of
    ``source_row_ids[i]`` against ``target_row_ids`` (in column order).
    Sets ``task.high_match_count`` / ``task.low_confidence_count`` from each
    source row's top-ranked candidate. Progress is committed periodically.
    The caller commits the final state.
    """
    # top_k is the initial candidate count; rerank_top_k is how many are kept
    # after reranking (defaults to top_k).
    top_k = task.top_k
    initial_k = min(top_k, len(target_row_ids))
    rerank_top_k = task.rerank_top_k or top_k
    task_pk = task.id
    scope = task.candidate_scope

    # Raw texts are only needed by the reranker.
    source_text = {row.id: row.raw_text for row in source_rows} if RERANKER_ENABLED else {}
    target_text = {row.id: row.raw_text for row in target_rows} if RERANKER_ENABLED else {}

    high_count = 0
    low_count = 0
    total_source = len(source_row_ids)
    for idx, src_id in enumerate(source_row_ids):
        scores = sim_matrix[idx]
        top_indices = np.argsort(scores)[::-1][:initial_k]  # most similar first
        candidates = [
            {
                "tgt_row_id": target_row_ids[tgt_idx],
                "sim_score": float(scores[tgt_idx]),
                "rerank_score": None,
            }
            for tgt_idx in top_indices
        ]
        progress = min(100, int((idx + 1) / total_source * 100))

        # Step 6.5: rerank candidates
        if RERANKER_ENABLED and candidates:
            candidates = await _rerank_for_source(
                src_id,
                source_text.get(src_id, ""),
                [target_text.get(c["tgt_row_id"], "") for c in candidates],
                candidates,
                rerank_top_k,
                top_k,
            )
            task.progress_rerank = progress
            if idx % _RERANK_COMMIT_EVERY == 0:
                _safe_commit(db)
        else:
            candidates = candidates[:top_k]

        for rank, cand in enumerate(candidates, start=1):
            db.add(
                MatchResult(
                    task_id=task_pk,
                    source_row_id=src_id,
                    target_row_id=cand["tgt_row_id"],
                    similarity_score=cand["sim_score"],
                    rerank_score=cand["rerank_score"],
                    rank=rank,
                    rerank_rank=rank if cand["rerank_score"] is not None else None,
                    candidate_scope=scope,
                    match_level=get_match_level(cand["sim_score"]),
                )
            )
        if candidates:
            best = candidates[0]["sim_score"]
            if best >= HIGH_MATCH_THRESHOLD:
                high_count += 1
            elif best < LOW_CONFIDENCE_THRESHOLD:
                low_count += 1

        task.progress_save_results = progress
        if idx % _SAVE_COMMIT_EVERY == 0:
            _safe_commit(db)

    task.high_match_count = high_count
    task.low_confidence_count = low_count


def _mark_failed(db: Session, task_id: int) -> None:
    """Best effort: set the task's status to "failed" in a fresh transaction.

    It rolls back first because the error being handled may have left the
    session needing a rollback (e.g. a failed flush). In that state any further
    query would raise and hide the original error. Work not yet committed is
    discarded. Errors here are printed and never raised, so the caller can
    re-raise the original exception.
    """
    try:
        db.rollback()
        task = db.query(VectorMatchTask).get(task_id)
        if task:
            task.status = "failed"
            _safe_commit(db)
    except Exception as exc:
        print(f"[WARN] Could not mark task {task_id} as failed: {exc}")


async def run_match_task(task_id: int, db_factory):
    """Main matching pipeline: parse → vectorize → match → save results.

    Args:
        task_id: primary key of the VectorMatchTask to run.
        db_factory: zero-argument callable returning a new Session. The session
            is used only here and is always closed before returning.

    Progress fields on the task are updated and committed as each step
    finishes. Returns silently if the task does not exist. On any exception
    the task is marked "failed" (best effort) and the exception is re-raised.
    """
    db: Session = db_factory()
    try:
        task = db.query(VectorMatchTask).get(task_id)
        if not task:
            return

        task.status = "running"
        _safe_commit(db)

        # Steps 1-2: no parsing is done in this function; these phases are
        # only marked complete.
        task.progress_parse_source = 100
        _safe_commit(db)
        task.progress_parse_target = 100
        _safe_commit(db)

        # Step 3: vectorize
        source_rows = _dataset_rows(db, task.source_dataset_id)
        target_rows = _dataset_rows(db, task.target_dataset_id)
        task.source_row_count = len(source_rows)
        task.target_row_count = len(target_rows)
        _safe_commit(db)

        await _vectorize_rows(db, task, source_rows + target_rows)
        task.progress_vectorize = 100
        _safe_commit(db)

        # Step 4: load candidate range
        task.progress_load_candidates = 100
        _safe_commit(db)

        # Step 5: similarity calculation
        source_embeddings, source_row_ids = _load_embeddings(db, source_rows)
        target_embeddings, target_row_ids = _load_embeddings(db, target_rows)

        if not source_embeddings or not target_embeddings:
            task.status = "completed"
            task.progress_similarity = 100
            task.progress_save_results = 100
            _safe_commit(db)
            return

        sim_matrix = batch_cosine_similarity(
            np.stack(source_embeddings), np.stack(target_embeddings)
        )
        task.progress_similarity = 100
        _safe_commit(db)

        # Step 6: Top-K candidates per source row, optional rerank, save results
        await _rank_and_save_matches(
            db, task, sim_matrix, source_row_ids, target_row_ids, source_rows, target_rows
        )

        task.progress_rerank = 100
        task.progress_save_results = 100
        task.status = "completed"
        _safe_commit(db)
    except Exception:
        _mark_failed(db, task_id)
        raise
    finally:
        db.close()