s-diffusion committed on
Commit
1328de1
·
1 Parent(s): 1e4bcab
Files changed (2) hide show
  1. app.py +81 -4
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,7 +1,84 @@
1
  from fastapi import FastAPI
 
 
 
 
 
 
 
 
2
 
3
app = FastAPI()


@app.get("/")
def greet_json():
    """Root endpoint: respond with a fixed hello-world JSON payload."""
    payload = {"Hello": "World!"}
    return payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel, Field
3
+ from typing import List, Optional
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import math
6
+ from starlette.middleware.base import BaseHTTPMiddleware
7
+ import time
8
+ import logging
9
+ import datetime
10
 
 
11
 
12
+
13
# --- Application setup: logging, FastAPI app, embedding model. ---

# Attach a handler via basicConfig so INFO messages actually reach stdout;
# a bare getLogger()/setLevel() pair emits nothing unless a handler is
# configured elsewhere. The formatter supplies timestamps, so the previous
# manual datetime.datetime.now() prefixes are no longer needed.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("duplicate")
logger.setLevel(logging.INFO)

app = FastAPI(title="新闻语义去重服务")

# Load the local embedding model once, at import/startup time (blocking).
logger.info("Loading model...")
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
logger.info("Model loaded.")
22
+
23
class DuplicateRequest(BaseModel):
    """Request payload for the duplicate-check endpoint."""

    # Text to be checked for duplication.
    source_text: str = Field(..., description="待检查文本")
    # Candidate texts the source is compared against.
    compare_text_list: List[str] = Field(..., description="对比文本列表")
    # Cosine-similarity threshold in [0, 1]; default 0.85.
    threshold: float = Field(0.85, ge=0, le=1, description="相似度阈值 0~1")
27
+
28
class DuplicateResponse(BaseModel):
    """Response payload: whether a duplicate was found, and where."""

    is_duplicate: bool
    max_similarity: float
    # Index into compare_text_list of the matching candidate; None if no match.
    index: Optional[int] = None
32
+
33
+
34
class TimerMiddleware(BaseHTTPMiddleware):
    """Log the wall-clock duration of every HTTP request, in milliseconds."""

    async def dispatch(self, request, call_next):
        started = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = (time.perf_counter() - started) * 1000  # ms
        logger.info(f"{request.method} {request.url.path} cost={elapsed_ms:.2f}ms")
        return response
41
+
42
+
43
BATCH_SIZE = 32  # candidates encoded per batch; tune 16-64 to the machine
44
+
45
+
46
def _check_duplicate_sync(req):
    """Check req.source_text against req.compare_text_list for near-duplicates.

    Encodes the source once, then scans candidates in batches of BATCH_SIZE.
    The first candidate whose cosine similarity reaches req.threshold
    short-circuits the scan; otherwise the highest similarity observed is
    reported.

    Returns a dict shaped like DuplicateResponse:
    {"is_duplicate": bool, "max_similarity": float, "index": Optional[int]}.
    """
    # Encode the source text exactly once.
    source_emb = model.encode(req.source_text, convert_to_tensor=True)

    total = len(req.compare_text_list)
    batches = math.ceil(total / BATCH_SIZE)

    best_sim = 0.0  # highest similarity seen so far (0.0 for an empty list)
    for b in range(batches):
        start = b * BATCH_SIZE
        end = min(start + BATCH_SIZE, total)
        batch_texts = req.compare_text_list[start:end]

        # Encode the current batch of candidates.
        batch_emb = model.encode(batch_texts, convert_to_tensor=True)

        # Cosine similarity of the source against every candidate in the batch.
        sim_scores = util.cos_sim(source_emb, batch_emb)[0]

        for i, score in enumerate(sim_scores):
            sim = float(score)
            if sim >= req.threshold:
                # Short-circuit on the first candidate over the threshold.
                return {
                    "is_duplicate": True,
                    "max_similarity": sim,
                    "index": start + i,
                }
            if sim > best_sim:
                best_sim = sim

    # No candidate reached the threshold. Report the best similarity actually
    # seen (previously hard-coded to 0.0, which hid how close the nearest
    # candidate came to the threshold).
    return {"is_duplicate": False, "max_similarity": best_sim, "index": None}
76
+
77
+
78
@app.post("/api/check_duplicate", response_model=DuplicateResponse)
def check_duplicate(req: DuplicateRequest):
    """Duplicate-check endpoint.

    Deliberately a plain `def` rather than `async def`: the model encoding
    inside _check_duplicate_sync is blocking CPU work, and FastAPI runs sync
    endpoints in its threadpool, keeping the event loop responsive. (The
    original `async def` executed the blocking call directly on the event
    loop, stalling all other requests for its duration.)
    """
    return _check_duplicate_sync(req)
81
+
82
+
83
+
84
+ app.add_middleware(TimerMiddleware)
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  fastapi
2
  uvicorn[standard]
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ sentence-transformers
4
+