Spaces:
Running
Running
import datetime
import logging
import math
import time
from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
# Service logger. A logger with no handler anywhere in its hierarchy falls back
# to logging's "last resort" handler, which only emits WARNING and above, so the
# INFO-level model-load and request-timing messages would be silently dropped.
# Attach a timestamped stream handler only when nothing else already handles
# these records (avoids duplicate output if logging is configured elsewhere).
logger = logging.getLogger("duplicate")
logger.setLevel(logging.INFO)
if not logger.hasHandlers():
    _log_handler = logging.StreamHandler()
    _log_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(_log_handler)

# Application title: "News semantic de-duplication service".
app = FastAPI(title="新闻语义去重服务")

# Load the sentence-embedding model once at import time so all requests share
# it. The argument is a Hugging Face Hub model id; presumably it is downloaded
# on first start and read from the local cache afterwards — confirm for the
# deployment environment. Timestamps now come from the handler's formatter.
logger.info("Loading model...")
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
logger.info("Model loaded.")
class DuplicateRequest(BaseModel):
    # Request payload: one text to check plus the candidate texts it is
    # compared against.

    # The text being checked for duplicates (required).
    source_text: str = Field(default=..., description="待检查文本")
    # Candidate texts (required); a reported match index is a position in this list.
    compare_text_list: List[str] = Field(default=..., description="对比文本列表")
    # Similarity cutoff, validated to lie in [0, 1]; defaults to 0.85.
    threshold: float = Field(default=0.85, description="相似度阈值 0~1", le=1, ge=0)
class DuplicateResponse(BaseModel):
    # Result of a duplicate check.

    # True when some candidate text reached the similarity threshold.
    is_duplicate: bool
    # Similarity score reported for the check.
    max_similarity: float
    # Position of the matching candidate in the request list; None when there is no match.
    index: Optional[int] = Field(default=None)
class TimerMiddleware(BaseHTTPMiddleware):
    """Log how long each HTTP request takes to handle.

    Elapsed wall-clock time is measured with ``time.perf_counter`` around the
    downstream handler and logged in milliseconds with the request method and
    path.
    """

    async def dispatch(self, request, call_next):
        start = time.perf_counter()
        try:
            return await call_next(request)
        finally:
            # Measured in ``finally`` so requests whose handler raises are
            # timed and logged too; the exception still propagates unchanged.
            cost = (time.perf_counter() - start) * 1000  # ms
            logger.info("%s %s cost=%.2fms", request.method, request.url.path, cost)
# Number of comparison texts encoded per batch; tune to the machine (16~64).
BATCH_SIZE = 32


def _check_duplicate_sync(req):
    """Check whether ``req.source_text`` nearly duplicates any comparison text.

    Comparison texts are encoded in batches of ``BATCH_SIZE`` and scored
    against the source embedding with cosine similarity. Scanning stops at the
    first text whose similarity is ``>= req.threshold``; later batches are not
    encoded.

    Args:
        req: Object exposing ``source_text`` (str), ``compare_text_list``
            (list of str) and ``threshold`` (float), e.g. ``DuplicateRequest``.

    Returns:
        dict with keys:
          - ``is_duplicate``: True if some text reached the threshold.
          - ``max_similarity``: on a hit, the similarity of that first matching
            text (texts after it are not scored, so a higher score may exist);
            on a miss, the highest similarity seen across all texts, or 0.0
            when ``compare_text_list`` is empty.
          - ``index``: position of the matching text in
            ``compare_text_list``, or None when there is no hit.
    """
    texts = req.compare_text_list
    if not texts:
        # Nothing to compare against; skip encoding the source text entirely.
        return {"is_duplicate": False, "max_similarity": 0.0, "index": None}

    # Encode the source text once and reuse it for every batch.
    source_emb = model.encode(req.source_text, convert_to_tensor=True)

    best_sim = float("-inf")
    for start in range(0, len(texts), BATCH_SIZE):
        batch_emb = model.encode(texts[start:start + BATCH_SIZE], convert_to_tensor=True)
        # cos_sim returns a (1, batch) matrix for a single source embedding;
        # take row 0 and convert it to Python floats in one transfer rather
        # than converting each tensor element separately.
        scores = util.cos_sim(source_emb, batch_emb)[0].tolist()
        for offset, sim in enumerate(scores):
            if sim >= req.threshold:
                return {
                    "is_duplicate": True,
                    "max_similarity": sim,
                    "index": start + offset,
                }
        # Track the best score so a miss reports the real maximum, not 0.0.
        best_sim = max(best_sim, max(scores))

    return {"is_duplicate": False, "max_similarity": best_sim, "index": None}
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler, so as written it is never registered with ``app`` and the service
# exposes no endpoint. Confirm the intended path and add the decorator.
async def check_duplicate(req: DuplicateRequest):
    """Run a duplicate check without blocking the event loop.

    ``_check_duplicate_sync`` performs synchronous model encoding. Calling it
    directly inside this coroutine would stall the event loop and delay every
    other in-flight request until it finished. It is therefore offloaded to
    Starlette's worker thread pool and awaited.

    Returns:
        The dict produced by ``_check_duplicate_sync`` with keys
        ``is_duplicate``, ``max_similarity`` and ``index``.
    """
    return await run_in_threadpool(_check_duplicate_sync, req)
# Register the request-timing middleware so every request's duration is logged.
app.add_middleware(middleware_class=TimerMiddleware)