|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel, Field |
|
|
from typing import List, Optional |
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
import math |
|
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
|
import time |
|
|
import logging |
|
|
import datetime |
|
|
|
|
|
|
|
|
|
|
|
# Service logger. Level INFO so that startup and per-request timing messages
# are emitted.
logger = logging.getLogger("duplicate")
logger.setLevel(logging.INFO)
if not logger.hasHandlers():
    # With no handler anywhere in the logger hierarchy, INFO records fall
    # through to logging.lastResort, which only prints WARNING and above, so
    # every logger.info() call in this service would be silently dropped.
    # Attach a stderr handler (with timestamps) only in that case, so an
    # application-level logging configuration, when present, is left alone.
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
    )
    logger.addHandler(_handler)

app = FastAPI(title="新闻语义去重服务")

# Sentence-embedding model, loaded once at import time and used by every
# request through the module-level ``model`` name.
MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

logger.info("Loading model %s ...", MODEL_NAME)
_load_start = time.perf_counter()
model = SentenceTransformer(MODEL_NAME)
logger.info("Model loaded in %.1fs.", time.perf_counter() - _load_start)
|
|
|
|
|
class DuplicateRequest(BaseModel):
    """Request body for ``POST /api/check_duplicate``."""

    # The text to check (description: "text to be checked").
    source_text: str = Field(default=..., description="待检查文本")
    # Candidate texts the source is compared against
    # (description: "list of texts to compare").
    compare_text_list: List[str] = Field(default=..., description="对比文本列表")
    # Similarity cut-off in [0, 1]; a candidate whose similarity is >= this
    # value is treated as a duplicate.
    threshold: float = Field(default=0.85, ge=0, le=1, description="相似度阈值 0~1")
|
|
|
|
|
class DuplicateResponse(BaseModel):
    """Response body for ``POST /api/check_duplicate``."""

    # True when some compared text reached the request's threshold.
    is_duplicate: bool
    # Cosine-similarity score reported by the duplicate checker.
    max_similarity: float
    # Position of the matching text in compare_text_list; None when no
    # duplicate was found.
    index: Optional[int] = Field(default=None)
|
|
|
|
|
|
|
|
class TimerMiddleware(BaseHTTPMiddleware):
    """Log the wall-clock latency of every HTTP request in milliseconds."""

    async def dispatch(self, request, call_next):
        """Time the downstream handler and log method, path and cost.

        The log line is written in ``finally`` so requests whose handler
        raises are still timed; the exception itself propagates unchanged.
        """
        start = time.perf_counter()
        try:
            return await call_next(request)
        finally:
            cost = (time.perf_counter() - start) * 1000
            # Lazy %-args: formatting is skipped when INFO is disabled.
            logger.info("%s %s cost=%.2fms", request.method, request.url.path, cost)
|
|
|
|
|
|
|
|
# Number of comparison texts embedded per model call. Encoding in chunks lets
# the scan stop as soon as one chunk contains a match instead of embedding the
# whole list up front.
BATCH_SIZE = 32


def _check_duplicate_sync(req):
    """Check whether ``req.source_text`` duplicates any text in ``req.compare_text_list``.

    Comparison texts are embedded in chunks of ``BATCH_SIZE`` and scored by
    cosine similarity against the source embedding. The scan stops at the
    first text whose similarity is ``>= req.threshold``.

    Args:
        req: object with ``source_text`` (str), ``compare_text_list``
            (list of str) and ``threshold`` (float), e.g. a ``DuplicateRequest``.

    Returns:
        A dict shaped like ``DuplicateResponse``:

        * match found: ``is_duplicate`` True, ``max_similarity`` is the score
          of the first matching text, ``index`` its position in the list;
        * no match: ``is_duplicate`` False, ``max_similarity`` is the highest
          score observed over all texts (0.0 for an empty list), ``index``
          is None.
    """
    texts = req.compare_text_list
    if not texts:
        # Nothing to compare against: skip the model call entirely.
        return {"is_duplicate": False, "max_similarity": 0.0, "index": None}

    source_emb = model.encode(req.source_text, convert_to_tensor=True)

    best = None
    for start in range(0, len(texts), BATCH_SIZE):
        batch_emb = model.encode(texts[start:start + BATCH_SIZE], convert_to_tensor=True)
        # cos_sim yields a score matrix; row 0 holds the source's scores
        # against this chunk. tolist() converts them to Python floats once.
        scores = util.cos_sim(source_emb, batch_emb)[0].tolist()
        for offset, sim in enumerate(scores):
            if sim >= req.threshold:
                return {
                    "is_duplicate": True,
                    "max_similarity": sim,
                    "index": start + offset,
                }
        # Track the best score so a non-duplicate answer reports the real
        # maximum instead of a hard-coded 0.0.
        chunk_best = max(scores)
        if best is None or chunk_best > best:
            best = chunk_best

    return {"is_duplicate": False, "max_similarity": best, "index": None}
|
|
|
|
|
|
|
|
@app.post("/api/check_duplicate", response_model=DuplicateResponse)
async def check_duplicate(req: DuplicateRequest):
    """Report whether ``req.source_text`` semantically duplicates a compared text.

    ``_check_duplicate_sync`` runs blocking, CPU-bound model inference.
    Calling it directly inside this coroutine would block the event loop and
    stall every other in-flight request. It is therefore offloaded to the
    worker thread pool.
    """
    from starlette.concurrency import run_in_threadpool

    return await run_in_threadpool(_check_duplicate_sync, req)
|
|
|
|
|
|
|
|
|
|
|
# Install the request-latency logger around every request handled by the app.
app.add_middleware(middleware_class=TimerMiddleware)
|
|
|