Spaces:
Running
Running
| import asyncio | |
| import numpy as np | |
| from typing import List, Dict, Optional | |
| from sqlalchemy.orm import Session | |
| from models import ( | |
| VectorMatchTask, VectorDataset, VectorDataRow, | |
| VectorEmbedding, MatchResult | |
| ) | |
| from services.embedding_service import ( | |
| get_embeddings_batch, batch_cosine_similarity, | |
| embedding_to_bytes, bytes_to_embedding, get_match_level, | |
| rerank_candidates, RERANKER_ENABLED | |
| ) | |
# Number of data rows handled per iteration of the vectorize step of the
# matching pipeline: each batch leads to at most one get_embeddings_batch
# call and is followed by one commit.
BATCH_SIZE = 32
def _safe_commit(db):
    """Commit the current transaction without ever raising on commit failure.

    If the first ``db.commit()`` fails, the transaction is rolled back and the
    commit is attempted once more; if that also fails, the transaction is
    rolled back again. Failures are reported on stdout instead of being
    swallowed silently.

    NOTE(review): with a SQLAlchemy ``Session``, ``rollback()`` presumably
    discards the pending changes, so the retry mostly serves to leave the
    session in a usable state rather than to persist the lost work -- confirm
    that dropping the batch is acceptable for callers.

    Args:
        db: Any object exposing ``commit()`` and ``rollback()``.

    Returns:
        True if the first commit succeeded, False if a rollback was needed.
        Existing callers ignore the return value.

    Raises:
        Whatever ``db.rollback()`` raises; only commit errors are absorbed.
    """
    try:
        db.commit()
        return True
    except Exception as first_err:
        print(f"[WARN] Commit failed, rolling back and retrying: {first_err}")
        db.rollback()

    try:
        db.commit()
    except Exception as retry_err:
        print(f"[WARN] Commit retry failed, transaction rolled back: {retry_err}")
        db.rollback()
    return False
# Similarity cut-offs applied to the best candidate of each source row when
# filling the task's summary counters.
_HIGH_MATCH_THRESHOLD = 0.90
_LOW_CONFIDENCE_THRESHOLD = 0.70


async def _vectorize_rows(db, task, rows):
    """Ensure every row in ``rows`` has a VectorEmbedding, in BATCH_SIZE chunks.

    Rows that already own an embedding are counted as reused. Rows whose
    ``text_hash`` matches an embedding owned by another row get a copy of that
    embedding (also counted as reused). Everything else is sent to
    ``get_embeddings_batch``. Task progress and counters are committed after
    every batch.
    """
    reused = 0
    new_count = 0
    total = max(len(rows), 1)

    for start in range(0, len(rows), BATCH_SIZE):
        batch = rows[start : start + BATCH_SIZE]

        # One ids-only query per batch tells us which rows already own an
        # embedding. Looking up by text_hash alone (as before) re-created an
        # embedding on every run for rows that share a hash with another row.
        batch_ids = [row.id for row in batch]
        covered_ids = {
            owner_id
            for (owner_id,) in db.query(VectorEmbedding.data_row_id)
            .filter(VectorEmbedding.data_row_id.in_(batch_ids))
            .all()
        }

        texts_to_embed = []
        rows_to_embed = []
        for row in batch:
            if row.id in covered_ids:
                reused += 1
                continue
            # From here on the row is handled in this batch; a second
            # occurrence of the same row (source dataset == target dataset)
            # must not produce a second embedding.
            covered_ids.add(row.id)

            existing = (
                db.query(VectorEmbedding)
                .filter(VectorEmbedding.text_hash == row.text_hash)
                .first()
            )
            if existing and existing.data_row_id != row.id:
                # Same text already embedded for another row: copy the vector.
                db.add(
                    VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=existing.embedding,
                        model_name=existing.model_name,
                        dimension=existing.dimension,
                    )
                )
                reused += 1
            elif existing:
                reused += 1
            else:
                texts_to_embed.append(row.raw_text)
                rows_to_embed.append(row)

        if texts_to_embed:
            embeddings = await get_embeddings_batch(texts_to_embed)
            for row, vec in zip(rows_to_embed, embeddings):
                db.add(
                    VectorEmbedding(
                        data_row_id=row.id,
                        text_hash=row.text_hash,
                        embedding=embedding_to_bytes(vec),
                        model_name="default",
                        dimension=len(vec),
                    )
                )
                new_count += 1

        task.progress_vectorize = min(
            100, int((start + len(batch)) / total * 100)
        )
        task.reused_vectors = reused
        task.new_vectors = new_count
        _safe_commit(db)


def _load_embeddings(db, rows):
    """Return ``(vectors, row_ids)`` for the rows that have a stored embedding.

    Rows without a VectorEmbedding are skipped, so both lists stay aligned
    with each other but may be shorter than ``rows``.
    """
    vectors = []
    row_ids = []
    for row in rows:
        emb = (
            db.query(VectorEmbedding)
            .filter(VectorEmbedding.data_row_id == row.id)
            .first()
        )
        if emb:
            vectors.append(bytes_to_embedding(emb.embedding))
            row_ids.append(row.id)
    return vectors, row_ids


async def run_match_task(task_id: int, db_factory) -> None:
    """Main matching pipeline: parse -> vectorize -> match -> save results.

    Loads the VectorMatchTask, makes sure all source/target rows have
    embeddings, computes the source x target cosine similarity matrix, keeps
    the top candidates per source row (optionally reranked when
    RERANKER_ENABLED is set) and stores them as MatchResult rows. Progress
    fields on the task are updated and committed along the way.

    Args:
        task_id: Identifier passed to ``Query.get`` to load the task. If no
            task is found the function returns without doing anything.
        db_factory: Zero-argument callable returning a database session. The
            session is always closed before this coroutine finishes.

    Raises:
        Exception: Any error raised by the pipeline is re-raised after an
            attempt to set the task status to ``"failed"``.
    """
    db: Session = db_factory()
    try:
        task = db.query(VectorMatchTask).get(task_id)
        if not task:
            return

        task.status = "running"
        _safe_commit(db)

        # Step 1: Parse source
        task.progress_parse_source = 100
        _safe_commit(db)

        # Step 2: Parse target
        task.progress_parse_target = 100
        _safe_commit(db)

        # Step 3: Vectorize
        source_rows = (
            db.query(VectorDataRow)
            .filter(VectorDataRow.dataset_id == task.source_dataset_id)
            .all()
        )
        target_rows = (
            db.query(VectorDataRow)
            .filter(VectorDataRow.dataset_id == task.target_dataset_id)
            .all()
        )
        task.source_row_count = len(source_rows)
        task.target_row_count = len(target_rows)
        _safe_commit(db)

        await _vectorize_rows(db, task, source_rows + target_rows)
        task.progress_vectorize = 100
        _safe_commit(db)

        # Step 4: Load candidate range
        task.progress_load_candidates = 100
        _safe_commit(db)

        # Step 5: Similarity calculation
        source_embeddings, source_row_ids = _load_embeddings(db, source_rows)
        target_embeddings, target_row_ids = _load_embeddings(db, target_rows)

        if not source_embeddings or not target_embeddings:
            # Nothing to compare: finish with every remaining stage marked
            # done, consistent with the normal completion path below.
            task.status = "completed"
            task.progress_similarity = 100
            task.progress_rerank = 100
            task.progress_save_results = 100
            _safe_commit(db)
            return

        source_matrix = np.stack(source_embeddings)
        target_matrix = np.stack(target_embeddings)
        sim_matrix = batch_cosine_similarity(source_matrix, target_matrix)
        task.progress_similarity = 100
        _safe_commit(db)

        # Step 6: Collect Top-K candidates per source row.
        # top_k is the initial candidate count; rerank_top_k is how many are
        # kept after reranking. Read the task settings once instead of on
        # every loop iteration.
        task_pk = task.id
        top_k = task.top_k
        rerank_top_k = task.rerank_top_k or top_k
        candidate_scope = task.candidate_scope
        initial_k = min(top_k, len(target_row_ids))

        # Build raw_text lookup for reranker
        source_text_map = {}
        target_text_map = {}
        if RERANKER_ENABLED:
            source_text_map = {row.id: row.raw_text for row in source_rows}
            target_text_map = {row.id: row.raw_text for row in target_rows}

        high_count = 0
        low_count = 0
        total_source = len(source_row_ids)

        for idx, src_id in enumerate(source_row_ids):
            scores = sim_matrix[idx]
            # Indices of the highest-scoring targets, best first.
            top_indices = np.argsort(scores)[::-1][:initial_k]
            candidates = [
                {
                    "tgt_row_id": target_row_ids[tgt_idx],
                    "sim_score": float(scores[tgt_idx]),
                    "rerank_score": None,
                }
                for tgt_idx in top_indices
            ]

            # Step 6.5: Rerank candidates
            if RERANKER_ENABLED and candidates:
                query_text = source_text_map.get(src_id, "")
                doc_texts = [
                    target_text_map.get(c["tgt_row_id"], "") for c in candidates
                ]
                try:
                    rerank_results = await rerank_candidates(
                        query=query_text,
                        documents=doc_texts,
                        top_n=rerank_top_k,
                    )
                    # Map rerank scores back to candidates; ignore indices
                    # that do not point at one of the submitted documents.
                    for rr in rerank_results:
                        orig_idx = rr["index"]
                        if 0 <= orig_idx < len(candidates):
                            candidates[orig_idx]["rerank_score"] = rr[
                                "relevance_score"
                            ]
                    # Sort by rerank_score (desc), unscored candidates last,
                    # then keep rerank_top_k.
                    candidates.sort(
                        key=lambda c: c["rerank_score"]
                        if c["rerank_score"] is not None
                        else -1,
                        reverse=True,
                    )
                    candidates = candidates[:rerank_top_k]
                except Exception as e:
                    # Best effort: fall back to plain similarity order.
                    print(f"[WARN] Rerank failed for source {src_id}: {e}")
                    candidates = candidates[:top_k]

                task.progress_rerank = min(
                    100, int((idx + 1) / total_source * 100)
                )
                if idx % 20 == 0:
                    _safe_commit(db)
            else:
                candidates = candidates[:top_k]

            # Save results
            for rank, c in enumerate(candidates):
                db.add(
                    MatchResult(
                        task_id=task_pk,
                        source_row_id=src_id,
                        target_row_id=c["tgt_row_id"],
                        similarity_score=c["sim_score"],
                        rerank_score=c["rerank_score"],
                        rank=rank + 1,
                        rerank_rank=(
                            rank + 1 if c["rerank_score"] is not None else None
                        ),
                        candidate_scope=candidate_scope,
                        match_level=get_match_level(c["sim_score"]),
                    )
                )
                # Summary counters look at the best candidate only.
                if rank == 0:
                    if c["sim_score"] >= _HIGH_MATCH_THRESHOLD:
                        high_count += 1
                    elif c["sim_score"] < _LOW_CONFIDENCE_THRESHOLD:
                        low_count += 1

            task.progress_save_results = min(
                100, int((idx + 1) / total_source * 100)
            )
            if idx % 50 == 0:
                _safe_commit(db)

        task.high_match_count = high_count
        task.low_confidence_count = low_count
        task.progress_rerank = 100
        task.progress_save_results = 100
        task.status = "completed"
        _safe_commit(db)
    except Exception:
        # A failed flush leaves the session unusable until it is rolled back,
        # so roll back before touching it again. Any problem while recording
        # the failure must not hide the original error.
        try:
            db.rollback()
            task = db.query(VectorMatchTask).get(task_id)
            if task:
                task.status = "failed"
                _safe_commit(db)
        except Exception as mark_err:
            print(f"[WARN] Could not mark task {task_id} as failed: {mark_err}")
        raise
    finally:
        db.close()