Cross-Encoder / app.py
fiewolf1000's picture
Update app.py
7346c55 verified
raw
history blame
8.29 kB
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import APIKeyQuery, APIKeyHeader
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
from typing import List, Optional
from datetime import datetime # 需在文件开头导入
from fastapi.responses import MarkdownResponse
import os
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
class CrossEncoderWrapper:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
# 1. 初始化FastAPI应用
app = FastAPI(
title="Cross-Encoder 重排序API",
description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
version="1.0.0"
)
# 2. API Key 认证配置(支持Header或Query参数传递)
API_KEY = os.getenv("CROSS_ENCODER_API_KEY") # 生产环境从环境变量获取,避免硬编码
if not API_KEY:
raise ValueError("请先设置环境变量 CROSS_ENCODER_API_KEY")
# 支持两种认证方式:Header(推荐,更安全)或 Query(备用)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False, description="通过Header传递API Key")
api_key_query = APIKeyQuery(name="api_key", auto_error=False, description="通过URL参数传递API Key,如 ?api_key=xxx")
def get_api_key(
header_key: Optional[str] = Depends(api_key_header),
query_key: Optional[str] = Depends(api_key_query)
) -> str:
"""验证API Key,优先取Header中的值,其次取Query中的值"""
if header_key == API_KEY:
return header_key
elif query_key == API_KEY:
return query_key
raise HTTPException(
status_code=401,
detail="无效或缺失API Key(支持Header: X-API-Key 或 Query: ?api_key=xxx)",
headers={"WWW-Authenticate": "X-API-Key"}
)
# 3. 定义请求/响应数据模型(标准化格式)
class RerankRequest(BaseModel):
"""重排序请求模型"""
query: str # 用户查询(如“什么是机器学习?”)
documents: List[str] # 候选文档列表(需排序的文本集合)
top_k: Optional[int] = 5 # 需返回的Top N高相关文档,默认5
truncation: Optional[bool] = True # 是否截断过长文本,默认True
class DocumentScore(BaseModel):
"""单篇文档的排序结果(含分数)"""
document: str # 文档内容
score: float # 相关性分数(越高越相关)
rank: int # 排序名次(1为最高)
class RerankResponse(BaseModel):
"""重排序响应模型"""
request_id: str # 请求唯一标识(便于排查问题)
query: str # 回显请求的查询
top_k: int # 回显请求的Top K
results: List[DocumentScore] # 排序结果列表
model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 使用的模型名称
timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # 用标准库生成时间戳
# 4. 加载Cross-Encoder模型(全局初始化,避免重复加载)
class CrossEncoderLoader:
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
# 自动使用GPU(若有),否则用CPU
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
self.model.eval() # 推理模式,关闭Dropout
print(f"模型加载完成,使用设备:{self.device}")
def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
"""
核心重排序逻辑
:param query: 用户查询
:param documents: 候选文档列表
:param top_k: 返回Top N
:param truncation: 是否截断文本
:return: 排序后的DocumentScore列表
"""
if not documents:
raise ValueError("候选文档列表不能为空")
if top_k <= 0:
raise ValueError("top_k必须为正整数")
# 计算每篇文档的相关性分数
doc_scores = []
for doc in documents:
# 模型输入格式:query [SEP] document(SEP为模型默认分隔符)
inputs = self.tokenizer(
text=f"{query} {self.tokenizer.sep_token} {doc}",
return_tensors="pt",
padding="max_length",
truncation=truncation,
max_length=512 # 模型最大输入长度,MiniLM-L-6-v2支持512
).to(self.device)
# 推理(关闭梯度计算,提升速度)
with torch.no_grad():
outputs = self.model(**inputs)
# 模型输出的logits即为相关性分数(无需softmax,直接用原始值)
score = outputs.logits.item()
doc_scores.append((doc, score))
# 按分数降序排序,取Top K,并添加名次
sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
results = [
DocumentScore(
document=doc,
score=round(score, 4), # 分数保留4位小数,便于阅读
rank=i+1 # 名次从1开始
) for i, (doc, score) in enumerate(sorted_docs)
]
return results
# 初始化模型(全局唯一实例)
reranker = CrossEncoderLoader()
# 5. 定义API端点(标准POST接口)
@app.post(
path="/api/v1/rerank",
response_model=RerankResponse,
description="文本相关性重排序接口:输入查询和候选文档,返回Top K高相关文档及分数"
)
async def rerank_endpoint(
request: RerankRequest,
api_key: str = Depends(get_api_key) # 强制API Key认证
) -> RerankResponse:
try:
# 生成请求唯一标识(用UUID,需安装:pip install python-uuid)
import uuid
request_id = str(uuid.uuid4())
# 调用重排序逻辑
results = reranker.rerank(
query=request.query,
documents=request.documents,
top_k=request.top_k,
truncation=request.truncation
)
# 构造响应
return RerankResponse(
request_id=request_id,
query=request.query,
top_k=request.top_k,
results=results
)
except ValueError as e:
# 业务逻辑错误(如参数无效)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# 服务器内部错误(如模型加载失败)
raise HTTPException(status_code=500, detail=f"服务器内部错误:{str(e)}")
# 6. 健康检查接口(用于监控服务状态)
@app.get("/api/v1/health", description="服务健康检查接口")
async def health_check():
return {
"status": "healthy",
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"device": reranker.device,
"timestamp": str(pd.Timestamp.now())
}
# ------------------- 新增:根路径(/)首页路由 -------------------
@app.get("/", response_class=MarkdownResponse, description="API 首页(含调用指南)")
async def home_page():
"""根路径首页:展示 API 功能、调用示例、认证方式等"""
return f"""
# Cross-Encoder 重排序 API(兼容 GPT 格式)
基于 `cross-encoder/ms-marco-MiniLM-L-6-v2` 模型,提供文本相关性排序服务,支持 GPT 标准 API 调用格式。
## 核心功能
- 输入「查询语句 + 候选文档列表」,返回按相关性降序排列的结果(含分数、排名)
- 兼容 OpenAI 风格 API 格式,可直接用 OpenAI 库调用
- 支持 API Key 认证,保障接口安全"""
# 7. 本地运行入口(开发环境用)
if __name__ == "__main__":
import uvicorn
# 安装uvicorn:pip install uvicorn
uvicorn.run(
app="app:app",
host="0.0.0.0", # 允许外部访问
port=7860, # 端口号
reload=False # 生产环境关闭reload
)