Cross-Encoder / app.py
fiewolf1000's picture
Update app.py
8bfbcde verified
raw
history blame
14.3 kB
import os
import uuid
from datetime import datetime
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.responses import HTMLResponse # 仅保留 HTMLResponse,删除 MarkdownResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from typing import List, Optional
# ------------------- 1. 基础配置(缓存目录 + 环境变量) -------------------
# 设置 Hugging Face 缓存目录(可写目录,解决权限问题)
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
# 从环境变量获取 API Key(认证用,需在 Hugging Face Spaces 中配置)
API_KEY = os.getenv("CROSS_ENCODER_API_KEY")
if not API_KEY:
raise ValueError("请在 Hugging Face Spaces 中设置环境变量 CROSS_ENCODER_API_KEY")
# ------------------- 2. 初始化 FastAPI 应用(仅初始化一次) -------------------
app = FastAPI(
title="Cross-Encoder 重排序 API",
description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口(兼容 GPT 格式)",
version="1.0.0"
)
# ------------------- 3. API Key 认证配置(支持 Header/Query 两种方式) -------------------
# 支持 Header(推荐)和 Query(备用)传递 API Key
api_key_header = APIKeyHeader(
name="X-API-Key",
auto_error=False,
description="通过 Header 传递 API Key(推荐)"
)
api_key_query = APIKeyQuery(
name="api_key",
auto_error=False,
description="通过 URL 参数传递 API Key(如 ?api_key=xxx)"
)
def verify_api_key(
header_key: Optional[str] = Depends(api_key_header),
query_key: Optional[str] = Depends(api_key_query)
) -> str:
"""验证 API Key,优先使用 Header 中的值"""
if header_key == API_KEY:
return header_key
elif query_key == API_KEY:
return query_key
raise HTTPException(
status_code=401,
detail="无效或缺失 API Key(支持:Header: X-API-Key 或 Query: ?api_key=xxx)",
headers={"WWW-Authenticate": "X-API-Key"}
)
# ------------------- 4. 数据模型定义(请求/响应格式) -------------------
class RerankRequest(BaseModel):
"""重排序请求模型(支持基础重排序 + GPT 兼容格式)"""
query: str # 用户查询(如“什么是机器学习?”)
documents: List[str] # 候选文档列表(需排序的文本)
top_k: Optional[int] = 3 # 返回 Top N 高相关文档,默认 3
truncation: Optional[bool] = True # 是否截断过长文本(模型最大输入 512 Token)
class DocumentScore(BaseModel):
"""单篇文档的排序结果(含分数和排名)"""
document: str # 文档内容
score: float # 相关性分数(越高越相关)
rank: int # 排序名次(1 为最高)
class RerankResponse(BaseModel):
"""重排序响应模型(标准化格式)"""
request_id: str # 请求唯一标识(用于排查问题)
query: str # 回显用户查询
top_k: int # 回显返回的 Top N 数量
results: List[DocumentScore] # 排序结果列表
model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 使用的模型名
timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # 时间戳
# GPT 兼容格式的请求模型(适配 /v1/chat/completions 接口)
class GPTMessage(BaseModel):
role: str # 仅支持 "user" 角色
content: str # 格式:"query: [查询]; documents: [文档1]; [文档2]; ..."
class GPTRequest(BaseModel):
model: str # 固定为模型名,用于兼容 GPT 调用格式
messages: List[GPTMessage] # GPT 风格的消息列表
top_k: Optional[int] = 3 # 同 RerankRequest 的 top_k
class GPTResponse(BaseModel):
"""GPT 兼容的响应模型(模仿 OpenAI 格式)"""
id: str = f"rerank-{uuid.uuid4().hex[:10]}"
object: str = "chat.completion"
created: int = int(datetime.now().timestamp())
model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
choices: List[dict] = [] # 存储排序结果
usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
# ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
class CrossEncoderModel:
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.model_name = model_name
# 加载分词器和模型(从缓存目录加载,避免权限问题)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
# 自动选择设备(GPU 优先,无则用 CPU)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
self.model.eval() # 推理模式,关闭 Dropout
print(f"模型加载完成!使用设备:{self.device}")
def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
"""核心重排序逻辑:计算查询与文档的相关性并排序"""
# 参数校验
if not documents:
raise ValueError("候选文档列表不能为空")
if top_k <= 0 or top_k > len(documents):
raise ValueError(f"top_k 需在 1~{len(documents)} 之间")
# 计算每篇文档的相关性分数
doc_scores = []
for doc in documents:
# 模型输入格式:query [SEP] document(SEP 是模型默认分隔符)
input_text = f"{query} {self.tokenizer.sep_token} {doc}"
inputs = self.tokenizer(
input_text,
return_tensors="pt",
padding="max_length",
truncation=truncation,
max_length=512 # 模型最大输入长度(MiniLM-L-6-v2 支持 512 Token)
).to(self.device)
# 推理(关闭梯度计算,提升速度)
with torch.no_grad():
outputs = self.model(**inputs)
# 模型输出的 logits 即为相关性分数(无需 softmax,直接使用原始值)
score = outputs.logits.item()
doc_scores.append((doc, score))
# 按分数降序排序,取 Top K 并生成结果
sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
return [
DocumentScore(
document=doc,
score=round(score, 4), # 分数保留 4 位小数,便于阅读
rank=i+1 # 名次从 1 开始
) for i, (doc, score) in enumerate(sorted_docs)
]
# 初始化模型(全局唯一,避免重复加载)
reranker = CrossEncoderModel()
# ------------------- 6. API 端点定义 -------------------
# 6.1 根路径首页(HTML 格式,无 Markdown 依赖)
@app.get("/", response_class=HTMLResponse, description="API 首页(含调用指南)")
async def home_page():
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>Cross-Encoder 重排序 API</title>
<style>
body {{ font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }}
h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
h2 {{ color: #34495e; margin-top: 30px; }}
pre {{ background: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #e9ecef; overflow-x: auto; }}
table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
th, td {{ border: 1px solid #e9ecef; padding: 12px; text-align: left; }}
th {{ background-color: #f1f5f9; }}
.note {{ color: #6c757d; font-size: 0.9em; }}
.api-url {{ color: #3498db; font-weight: bold; }}
</style>
</head>
<body>
<h1>Cross-Encoder 重排序 API</h1>
<p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,提供文本相关性排序服务,支持 GPT 标准 API 调用格式。</p>
<h2>核心功能</h2>
<ul>
<li>输入「查询语句 + 候选文档列表」,返回按相关性降序排列的结果(含分数、排名)</li>
<li>支持两种 API 格式:基础重排序接口(/api/v1/rerank)和 GPT 兼容接口(/v1/chat/completions)</li>
<li>API Key 认证,保障接口安全</li>
</ul>
<h2>接口列表</h2>
<table>
<tr>
<th>接口名称</th>
<th>URL</th>
<th>方法</th>
<th>说明</th>
</tr>
<tr>
<td>基础重排序接口</td>
<td class="api-url">{app.root_path}/api/v1/rerank</td>
<td>POST</td>
<td>标准化重排序接口,返回结构化结果</td>
</tr>
<tr>
<td>GPT 兼容接口</td>
<td class="api-url">{app.root_path}/v1/chat/completions</td>
<td>POST</td>
<td>模仿 OpenAI 格式,可直接用 OpenAI 库调用</td>
</tr>
<tr>
<td>健康检查</td>
<td class="api-url">{app.root_path}/api/v1/health</td>
<td>GET</td>
<td>无需认证,检查服务状态</td>
</tr>
</table>
<h2>调用示例(GPT 兼容接口)</h2>
<pre><code>from openai import OpenAI
# 配置客户端(指向你的 Space 地址)
client = OpenAI(
api_key="your-api-key-here", # 替换为你的 API Key
base_url="https://&lt;your-username&gt;-&lt;your-space-name&gt;.hf.space/v1" # 替换为你的 Space URL
)
# 发送重排序请求
response = client.chat.completions.create(
model="cross-encoder/ms-marco-MiniLM-L-6-v2", # 固定模型名
messages=[
{{
"role": "user",
"content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言; 深度学习是机器学习的子集;"
}}
],
top_k=2 # 返回 Top 2 高相关文档
)
# 打印结果
print(response.choices[0].message.content)</code></pre>
<h2>API Key 认证方式</h2>
<p>所有 POST 接口需通过以下方式之一传递 API Key:</p>
<ul>
<li><strong>Header 方式(推荐)</strong>:在请求 Header 中添加 <code>X-API-Key: your-api-key</code></li>
<li><strong>Query 方式(备用)</strong>:在 URL 后添加 <code>?api_key=your-api-key</code></li>
</ul>
<p class="note">页面生成时间: {current_time} | 模型运行设备: {reranker.device}</p>
</body>
</html>
"""
# 6.2 基础重排序接口(标准化格式)
@app.post(
"/api/v1/rerank",
response_model=RerankResponse,
description="基础重排序接口,返回结构化的排序结果"
)
async def base_rerank(
request: RerankRequest,
api_key: str = Depends(verify_api_key)
):
try:
# 执行重排序
results = reranker.rerank(
query=request.query,
documents=request.documents,
top_k=request.top_k,
truncation=request.truncation
)
# 生成响应
return RerankResponse(
request_id=str(uuid.uuid4()),
query=request.query,
top_k=request.top_k,
results=results
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
# 6.3 GPT 兼容接口(模仿 OpenAI 格式)
@app.post(
"/v1/chat/completions",
response_model=GPTResponse,
description="GPT 兼容接口,支持用 OpenAI 库调用"
)
async def gpt_compatible_rerank(
request: GPTRequest,
api_key: str = Depends(verify_api_key)
):
try:
# 校验模型名(确保兼容 GPT 调用格式)
if request.model != reranker.model_name:
raise ValueError(f"仅支持模型:{reranker.model_name}")
# 校验消息(仅支持最后一条为 user 角色)
if not request.messages or request.messages[-1].role != "user":
raise ValueError("最后一条消息必须是 'user' 角色")
# 解析用户输入(从 content 中提取 query 和 documents)
content = request.messages[-1].content
if "; documents: " not in content:
raise ValueError("输入格式错误,需为:'query: [查询]; documents: [文档1]; [文档2]; ...'")
query_part, docs_part = content.split("; documents: ")
query = query_part.replace("query: ", "").strip()
documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
# 执行重排序
results = reranker.rerank(
query=query,
documents=documents,
top_k=request.top_k,
truncation=True
)
# 格式化 GPT 风格的响应
return GPTResponse(
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": f"重排序结果(按相关性降序):\n{[{'文档': r.document, '分数': r.score, '排名': r.rank} for r in results]}"
},
"finish_reason": "stop"
}]
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
# 6.4 健康检查接口
@app.get("/api/v1/health", description="服务健康检查接口(无需认证)")
async def health_check():
return {
"status": "healthy",
"model": reranker.model_name,
"device": reranker.device,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"message": "服务正常运行"
}
# ------------------- 7. 本地运行入口(开发环境用) -------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run(
app="app:app",
host="0.0.0.0",
port=7860,
reload=False # 生产环境关闭 reload
)