Commit
·
1328de1
1
Parent(s):
1e4bcab
new app
Browse files- app.py +81 -4
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -1,7 +1,84 @@
|
|
| 1 |
from fastapi import FastAPI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
app = FastAPI()
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from sentence_transformers import SentenceTransformer, util
|
| 5 |
+
import math
|
| 6 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 7 |
+
import time
|
| 8 |
+
import logging
|
| 9 |
+
import datetime
|
| 10 |
|
|
|
|
| 11 |
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger("duplicate")
|
| 14 |
+
logger.setLevel(logging.INFO)
|
| 15 |
+
|
| 16 |
+
app = FastAPI(title="新闻语义去重服务")
|
| 17 |
+
|
| 18 |
+
# 加载本地模型(启动服务时自动加载)
|
| 19 |
+
logger.info(f"{datetime.datetime.now()} Loading model...")
|
| 20 |
+
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
|
| 21 |
+
logger.info(f"{datetime.datetime.now()} Model loaded.")
|
| 22 |
+
|
| 23 |
+
class DuplicateRequest(BaseModel):
|
| 24 |
+
source_text: str = Field(..., description="待检查文本")
|
| 25 |
+
compare_text_list: List[str] = Field(..., description="对比文本列表")
|
| 26 |
+
threshold: float = Field(0.85, ge=0, le=1, description="相似度阈值 0~1")
|
| 27 |
+
|
| 28 |
+
class DuplicateResponse(BaseModel):
|
| 29 |
+
is_duplicate: bool
|
| 30 |
+
max_similarity: float
|
| 31 |
+
index: Optional[int] = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TimerMiddleware(BaseHTTPMiddleware):
|
| 35 |
+
async def dispatch(self, request, call_next):
|
| 36 |
+
start = time.perf_counter()
|
| 37 |
+
response = await call_next(request)
|
| 38 |
+
cost = (time.perf_counter() - start) * 1000 # ms
|
| 39 |
+
logger.info(f"{request.method} {request.url.path} cost={cost:.2f}ms")
|
| 40 |
+
return response
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
BATCH_SIZE = 32 # 每批对比数量,可根据机器调整 16~64
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _check_duplicate_sync(req):
|
| 47 |
+
# 只 encode 一次
|
| 48 |
+
source_emb = model.encode(req.source_text, convert_to_tensor=True)
|
| 49 |
+
|
| 50 |
+
total = len(req.compare_text_list)
|
| 51 |
+
batches = math.ceil(total / BATCH_SIZE)
|
| 52 |
+
|
| 53 |
+
for b in range(batches):
|
| 54 |
+
start = b * BATCH_SIZE
|
| 55 |
+
end = min(start + BATCH_SIZE, total)
|
| 56 |
+
batch_texts = req.compare_text_list[start:end]
|
| 57 |
+
|
| 58 |
+
# 编码当前批
|
| 59 |
+
batch_emb = model.encode(batch_texts, convert_to_tensor=True)
|
| 60 |
+
|
| 61 |
+
# 计算相似度
|
| 62 |
+
sim_scores = util.cos_sim(source_emb, batch_emb)[0]
|
| 63 |
+
|
| 64 |
+
# 检查是否有满足阈值的
|
| 65 |
+
for i, score in enumerate(sim_scores):
|
| 66 |
+
sim = float(score)
|
| 67 |
+
if sim >= req.threshold:
|
| 68 |
+
return {
|
| 69 |
+
"is_duplicate": True,
|
| 70 |
+
"max_similarity": sim,
|
| 71 |
+
"index": start + i
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
# 没有命中
|
| 75 |
+
return {"is_duplicate": False, "max_similarity": 0.0, "index": None}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@app.post("/api/check_duplicate", response_model=DuplicateResponse)
|
| 79 |
+
async def check_duplicate(req: DuplicateRequest):
|
| 80 |
+
return _check_duplicate_sync(req)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
app.add_middleware(TimerMiddleware)
|
requirements.txt
CHANGED
|
@@ -1,2 +1,4 @@
|
|
| 1 |
fastapi
|
| 2 |
uvicorn[standard]
|
|
|
|
|
|
|
|
|
| 1 |
fastapi
|
| 2 |
uvicorn[standard]
|
| 3 |
+
sentence-transformers
|
| 4 |
+
|