File size: 2,617 Bytes
0c83cbd
1328de1
 
 
 
 
 
 
 
0c83cbd
 
1328de1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import asyncio
import datetime
import logging
import math
import time
from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util
from starlette.middleware.base import BaseHTTPMiddleware



# Configure logging with a timestamped handler. Without basicConfig, this
# named logger has no handler, and logging's last-resort handler only emits
# WARNING and above -- so every logger.info() in this module was silently
# dropped. The %(asctime)s field also replaces the manual datetime.now()
# that was previously embedded in each message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("duplicate")

app = FastAPI(title="新闻语义去重服务")

# Load the local model once at service startup (blocks until ready).
logger.info("Loading model...")
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
logger.info("Model loaded.")

class DuplicateRequest(BaseModel):
    """Request payload for the duplicate-check endpoint."""
    # Text to be checked for duplication against the candidate list.
    source_text: str = Field(..., description="待检查文本")
    # Candidate texts the source text is compared against.
    compare_text_list: List[str] = Field(..., description="对比文本列表")
    # Cosine-similarity threshold in [0, 1]; scores >= threshold count as duplicate.
    threshold: float = Field(0.85, ge=0, le=1, description="相似度阈值 0~1")

class DuplicateResponse(BaseModel):
    """Response payload: whether a duplicate was found and where."""
    # True when some candidate's similarity met the threshold.
    is_duplicate: bool
    # Similarity score of the matched candidate; 0.0 when nothing matched.
    max_similarity: float
    # Index of the matching candidate in compare_text_list; None when no match.
    index: Optional[int] = None


class TimerMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that logs wall-clock duration of every request."""

    async def dispatch(self, request, call_next):
        start = time.perf_counter()
        response = await call_next(request)
        cost_ms = (time.perf_counter() - start) * 1000  # ms
        # Lazy %-args: the message string is only built if the record is emitted.
        logger.info("%s %s cost=%.2fms", request.method, request.url.path, cost_ms)
        return response


BATCH_SIZE = 32  # 每批对比数量,可根据机器调整 16~64


def _check_duplicate_sync(req):
    """Compare req.source_text against req.compare_text_list in batches.

    Encodes the source text once, then encodes candidates BATCH_SIZE at a
    time and computes cosine similarities. Returns on the FIRST candidate
    whose similarity reaches req.threshold (early exit, so `max_similarity`
    is that hit's score, not a global maximum).

    Returns a dict matching DuplicateResponse:
        is_duplicate   -- True when some candidate met the threshold.
        max_similarity -- the hit's score, or the highest score observed
                          across all candidates when nothing matched
                          (previously this was always 0.0 on a miss,
                          which made the field misleading).
        index          -- position of the hit in compare_text_list, or None.
    """
    # Encode the source only once; it is reused for every batch.
    source_emb = model.encode(req.source_text, convert_to_tensor=True)

    total = len(req.compare_text_list)
    best = 0.0  # highest similarity seen so far (reported on a miss)

    for start in range(0, total, BATCH_SIZE):
        batch_texts = req.compare_text_list[start:start + BATCH_SIZE]

        # Encode the current batch of candidates.
        batch_emb = model.encode(batch_texts, convert_to_tensor=True)

        # Cosine similarity of the source against each candidate in the batch.
        sim_scores = util.cos_sim(source_emb, batch_emb)[0]

        for i, score in enumerate(sim_scores):
            sim = float(score)
            if sim >= req.threshold:
                # Early exit on the first candidate that meets the threshold.
                return {
                    "is_duplicate": True,
                    "max_similarity": sim,
                    "index": start + i
                }
            if sim > best:
                best = sim

    # No candidate met the threshold; report the best score actually observed.
    return {"is_duplicate": False, "max_similarity": best, "index": None}


@app.post("/api/check_duplicate", response_model=DuplicateResponse)
async def check_duplicate(req: DuplicateRequest):
    """Check whether source_text duplicates any text in compare_text_list.

    The encoding work in _check_duplicate_sync is CPU-bound and blocking;
    calling it directly from an async endpoint would stall the event loop
    and freeze every concurrent request for the duration of the model run.
    asyncio.to_thread offloads it to a worker thread so the loop stays
    responsive.
    """
    return await asyncio.to_thread(_check_duplicate_sync, req)
    


app.add_middleware(TimerMiddleware)