"""Semantic news de-duplication service.

Exposes one POST endpoint that embeds a source text and a list of candidate
texts with a SentenceTransformer model and reports whether any candidate's
cosine similarity to the source reaches a caller-supplied threshold.
"""

import datetime
import logging
import math
import time
from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger("duplicate")
logger.setLevel(logging.INFO)

app = FastAPI(title="新闻语义去重服务")

# Load the model once at import time so every request reuses the same instance.
logger.info("%s Loading model...", datetime.datetime.now())
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
logger.info("%s Model loaded.", datetime.datetime.now())


class DuplicateRequest(BaseModel):
    """Request body for the duplicate check."""

    # Text to be checked against the candidates.
    source_text: str = Field(..., description="待检查文本")
    # Candidate texts the source is compared with.
    compare_text_list: List[str] = Field(..., description="对比文本列表")
    # Similarity threshold in [0, 1]; a score >= threshold counts as duplicate.
    threshold: float = Field(0.85, ge=0, le=1, description="相似度阈值 0~1")


class DuplicateResponse(BaseModel):
    """Response body for the duplicate check."""

    # True when some candidate reached the threshold.
    is_duplicate: bool
    # Similarity of the matching candidate on a hit; otherwise the highest
    # similarity seen across all candidates (0.0 when there were none).
    max_similarity: float
    # Position of the matching candidate in compare_text_list, or None.
    index: Optional[int] = None


class TimerMiddleware(BaseHTTPMiddleware):
    """Log the wall-clock duration of every request in milliseconds."""

    async def dispatch(self, request, call_next):
        """Run the downstream handler, log how long it took, return its response."""
        start = time.perf_counter()
        response = await call_next(request)
        cost = (time.perf_counter() - start) * 1000  # ms
        logger.info("%s %s cost=%.2fms", request.method, request.url.path, cost)
        return response


# Number of candidate texts encoded per model call; tune per machine (16~64).
BATCH_SIZE = 32


def _check_duplicate_sync(req: DuplicateRequest) -> dict:
    """Compare ``req.source_text`` against ``req.compare_text_list`` in batches.

    Candidates are encoded ``BATCH_SIZE`` at a time and scanned in order; the
    scan stops at the first candidate whose cosine similarity to the source is
    ``>= req.threshold``, so later batches are never encoded after a hit.

    Args:
        req: The validated request (source text, candidates, threshold).

    Returns:
        A dict with the keys of ``DuplicateResponse``:
        - on a hit: ``is_duplicate=True``, the similarity of the first
          candidate that reached the threshold, and its index in
          ``compare_text_list``;
        - on a miss: ``is_duplicate=False``, the highest similarity observed
          over all candidates (``0.0`` if the list is empty), ``index=None``.
    """
    # Encode the source only once; it is reused for every batch.
    source_emb = model.encode(req.source_text, convert_to_tensor=True)

    candidates = req.compare_text_list
    threshold = req.threshold
    best = None  # highest similarity seen so far; None until a score exists

    for start in range(0, len(candidates), BATCH_SIZE):
        # Encode the current batch of candidates.
        batch_emb = model.encode(
            candidates[start:start + BATCH_SIZE], convert_to_tensor=True
        )

        # cos_sim returns a (1, batch) matrix; row 0 holds source-vs-candidate
        # scores. Convert the row to Python floats in a single call.
        sim_scores = util.cos_sim(source_emb, batch_emb)[0].tolist()

        for offset, sim in enumerate(sim_scores):
            if sim >= threshold:
                return {
                    "is_duplicate": True,
                    "max_similarity": sim,
                    "index": start + offset,
                }
            if best is None or sim > best:
                best = sim

    # No candidate reached the threshold: report the real maximum rather than
    # a placeholder, so callers can see how close the nearest candidate was.
    return {
        "is_duplicate": False,
        "max_similarity": 0.0 if best is None else best,
        "index": None,
    }


@app.post("/api/check_duplicate", response_model=DuplicateResponse)
async def check_duplicate(req: DuplicateRequest):
    """Check whether the source text duplicates any of the candidate texts."""
    # NOTE(review): this coroutine calls the blocking encoder directly, so the
    # event loop is occupied for the whole computation. Consider offloading to
    # a worker thread once the model's thread-safety has been confirmed.
    return _check_duplicate_sync(req)


app.add_middleware(TimerMiddleware)