File size: 2,617 Bytes
0c83cbd 1328de1 0c83cbd 1328de1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import asyncio
import datetime
import logging
import math
import time
from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util
from starlette.middleware.base import BaseHTTPMiddleware
# Service-wide logger; its own level is set to INFO.
# NOTE(review): no handler is attached in this file, so whether these INFO
# records are actually emitted depends on logging configuration done
# elsewhere (e.g. by the server runner) - confirm.
logger = logging.getLogger("duplicate")
logger.setLevel(logging.INFO)
# FastAPI application object that the route and middleware below attach to.
app = FastAPI(title="新闻语义去重服务")
# Load the sentence-embedding model. This runs at module import time, so the
# model is loaded once when the service starts rather than per request.
logger.info(f"{datetime.datetime.now()} Loading model...")
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
logger.info(f"{datetime.datetime.now()} Model loaded.")
class DuplicateRequest(BaseModel):
    # Request body for the duplicate check: one text to test, the texts to
    # compare it against, and the similarity cut-off.

    # Text to be checked (required).
    source_text: str = Field(
        default=...,
        description="待检查文本",
    )
    # Texts the source text is compared against (required).
    compare_text_list: List[str] = Field(
        default=...,
        description="对比文本列表",
    )
    # Similarity threshold, validated to lie in [0, 1]; defaults to 0.85.
    threshold: float = Field(
        default=0.85,
        ge=0,
        le=1,
        description="相似度阈值 0~1",
    )
class DuplicateResponse(BaseModel):
    # Response body of the duplicate check endpoint.

    # True when some compared text reached the request's threshold.
    is_duplicate: bool
    # Similarity score reported by the check; on a hit this is the score of
    # the matching text.
    max_similarity: float
    # Position of the matching text in compare_text_list; None when no text
    # matched.
    index: Optional[int] = None
class TimerMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that logs how long each request took to handle."""

    async def dispatch(self, request, call_next):
        """Pass the request downstream and log its method, path and duration.

        Args:
            request: The incoming request; its ``method`` and ``url.path``
                appear in the log line.
            call_next: Awaitable callable that produces the response for
                ``request``.

        Returns:
            The response produced by ``call_next``, unmodified.
        """
        began_at = time.perf_counter()
        reply = await call_next(request)
        finished_at = time.perf_counter()

        # perf_counter() reports seconds; the log line is in milliseconds.
        cost = 1000 * (finished_at - began_at)
        logger.info(f"{request.method} {request.url.path} cost={cost:.2f}ms")
        return reply
BATCH_SIZE = 32  # Texts compared per batch; tune per machine (roughly 16-64).


def _check_duplicate_sync(req):
    """Check whether ``req.source_text`` duplicates any text in the request.

    The source text is encoded once. The comparison texts are then encoded
    and scored against it in batches of ``BATCH_SIZE``, and scanning stops at
    the first text whose cosine similarity reaches ``req.threshold``.

    Args:
        req: Object exposing ``source_text`` (str), ``compare_text_list``
            (sequence of str) and ``threshold`` (float).

    Returns:
        A dict with keys:

        * ``is_duplicate``: True when a compared text reached the threshold.
        * ``max_similarity``: on a hit, the similarity of the first matching
          text (later texts are not scored); on a miss, the highest
          similarity observed across all compared texts, or ``0.0`` when
          ``compare_text_list`` is empty.
        * ``index``: position of the first matching text in
          ``compare_text_list``, or ``None`` on a miss.
    """
    # Encode the source text only once and reuse it for every batch.
    source_emb = model.encode(req.source_text, convert_to_tensor=True)
    texts = req.compare_text_list
    threshold = req.threshold

    # Highest similarity seen so far; stays None until a score exists so an
    # empty comparison list can still report 0.0.
    best = None

    for start in range(0, len(texts), BATCH_SIZE):
        # Encode the current batch.
        batch_emb = model.encode(
            texts[start:start + BATCH_SIZE], convert_to_tensor=True
        )
        # Row 0 holds the similarities of the single source embedding against
        # the batch. tolist() turns them into plain Python floats in one call
        # instead of converting element by element (assumes a tensor result,
        # as requested via convert_to_tensor=True).
        sims = util.cos_sim(source_emb, batch_emb)[0].tolist()

        # Return on the first text that reaches the threshold.
        for offset, sim in enumerate(sims):
            if sim >= threshold:
                return {
                    "is_duplicate": True,
                    "max_similarity": sim,
                    "index": start + offset,
                }

        # No hit in this batch: remember its best score so a miss reports the
        # real maximum rather than a placeholder.
        batch_best = max(sims)
        if best is None or batch_best > best:
            best = batch_best

    # Nothing reached the threshold.
    return {
        "is_duplicate": False,
        "max_similarity": 0.0 if best is None else best,
        "index": None,
    }
@app.post("/api/check_duplicate", response_model=DuplicateResponse)
async def check_duplicate(req: DuplicateRequest):
    """Handle POST /api/check_duplicate.

    Args:
        req: Validated request body (source text, comparison texts,
            threshold).

    Returns:
        The result dict produced by ``_check_duplicate_sync``, serialized
        through ``DuplicateResponse``.
    """
    # _check_duplicate_sync is blocking (model encoding). Calling it directly
    # inside this coroutine would stall the event loop for its whole
    # duration, so run it in the loop's default executor and await the result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _check_duplicate_sync, req)
# Register the timing middleware so each request's method, path and elapsed
# time are logged.
app.add_middleware(TimerMiddleware)
|