import os

# Point the Hugging Face cache at a writable directory. This must be set
# before transformers is imported, or the default cache location is used.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import uuid
from datetime import datetime
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import PlainTextResponse
from fastapi.security import APIKeyQuery, APIKeyHeader
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# 1. Initialize the FastAPI application
app = FastAPI(
    title="Cross-Encoder Reranking API",
    description="Text relevance ranking service based on cross-encoder/ms-marco-MiniLM-L-6-v2",
    version="1.0.0"
)
# 2. API key authentication (accepts a header or a query parameter)
API_KEY = os.getenv("CROSS_ENCODER_API_KEY")  # read from the environment in production; never hard-code it
if not API_KEY:
    raise ValueError("Please set the CROSS_ENCODER_API_KEY environment variable first")

# Two ways to authenticate: header (recommended, safer) or query parameter (fallback)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False, description="Pass the API key via a request header")
api_key_query = APIKeyQuery(name="api_key", auto_error=False, description="Pass the API key as a URL parameter, e.g. ?api_key=xxx")
def get_api_key(
    header_key: Optional[str] = Depends(api_key_header),
    query_key: Optional[str] = Depends(api_key_query)
) -> str:
    """Validate the API key, preferring the header value over the query value."""
    if header_key == API_KEY:
        return header_key
    elif query_key == API_KEY:
        return query_key
    raise HTTPException(
        status_code=401,
        detail="Missing or invalid API key (pass it via the X-API-Key header or the ?api_key=xxx query parameter)",
        headers={"WWW-Authenticate": "X-API-Key"}
    )
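
# A minimal client sketch showing both authentication styles against the
# /rerank endpoint defined below (http://localhost:7860 and "my-secret-key"
# are placeholders for your own host and key):
#
#   import requests
#   payload = {"query": "what is machine learning?",
#              "documents": ["ML is a subfield of AI.", "Bananas are yellow."]}
#   # Header-based auth (recommended)
#   requests.post("http://localhost:7860/rerank",
#                 headers={"X-API-Key": "my-secret-key"}, json=payload)
#   # Query-parameter fallback
#   requests.post("http://localhost:7860/rerank?api_key=my-secret-key",
#                 json=payload)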
# 3. Request/response data models (standardized format)
class RerankRequest(BaseModel):
    """Reranking request model"""
    query: str  # the user query (e.g. "what is machine learning?")
    documents: List[str]  # candidate documents to be ranked
    top_k: Optional[int] = 5  # number of top-scoring documents to return, default 5
    truncation: Optional[bool] = True  # whether to truncate over-long text, default True

class DocumentScore(BaseModel):
    """Ranking result for a single document (with score)"""
    document: str  # document text
    score: float  # relevance score (higher means more relevant)
    rank: int  # rank position (1 is the most relevant)

class RerankResponse(BaseModel):
    """Reranking response model"""
    request_id: str  # unique request identifier (useful for debugging)
    query: str  # echo of the request query
    top_k: int  # echo of the requested top_k
    results: List[DocumentScore]  # ranked results
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # model name
    # default_factory yields a fresh timestamp per response; a plain default
    # would be evaluated only once, at class-definition time
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    )
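
# Illustrative JSON shapes these models accept and produce (values are
# examples only, not real output):
#
#   request:  {"query": "what is machine learning?",
#              "documents": ["ML is a subfield of AI.", "Bananas are yellow."],
#              "top_k": 2}
#   response: {"request_id": "<uuid4>", "query": "what is machine learning?",
#              "top_k": 2,
#              "results": [{"document": "ML is a subfield of AI.",
#                           "score": 7.1234, "rank": 1}, ...],
#              "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
#              "timestamp": "2024-01-01 12:00:00.000"}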
# 4. Load the Cross-Encoder model (initialized once globally to avoid reloading)
class CrossEncoderLoader:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Use the GPU if available, otherwise fall back to the CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()  # inference mode, disables dropout
        print(f"Model loaded, using device: {self.device}")
    def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
        """
        Core reranking logic.
        :param query: user query
        :param documents: candidate documents
        :param top_k: number of results to return
        :param truncation: whether to truncate over-long text
        :return: sorted list of DocumentScore objects
        """
        if not documents:
            raise ValueError("The candidate document list must not be empty")
        if top_k <= 0:
            raise ValueError("top_k must be a positive integer")
        # Score each document against the query
        doc_scores = []
        for doc in documents:
            # Tokenize as a (query, document) sentence pair; the tokenizer
            # inserts the model's separator token ([SEP]) automatically
            inputs = self.tokenizer(
                query,
                doc,
                return_tensors="pt",
                truncation=truncation,
                max_length=512  # maximum input length supported by MiniLM-L-6-v2
            ).to(self.device)
            # Inference with gradients disabled for speed
            with torch.no_grad():
                outputs = self.model(**inputs)
            # The raw logit is the relevance score (no softmax needed)
            score = outputs.logits.item()
            doc_scores.append((doc, score))
        # Sort by score in descending order, keep the top_k, and assign ranks
        sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
        results = [
            DocumentScore(
                document=doc,
                score=round(score, 4),  # round to 4 decimal places for readability
                rank=i + 1  # ranks start at 1
            ) for i, (doc, score) in enumerate(sorted_docs)
        ]
        return results
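
# A batched variant of CrossEncoderLoader.rerank's scoring loop (a sketch,
# not part of the service; it trades memory for speed by tokenizing every
# (query, document) pair in one call and running a single forward pass):
#
#   inputs = self.tokenizer(
#       [query] * len(documents), documents,
#       return_tensors="pt", padding=True,
#       truncation=truncation, max_length=512,
#   ).to(self.device)
#   with torch.no_grad():
#       scores = self.model(**inputs).logits.squeeze(-1).tolist()
#   doc_scores = list(zip(documents, scores))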
# Initialize the model (single global instance)
reranker = CrossEncoderLoader()
# 5. The reranking endpoint (standard POST interface)
@app.post("/rerank", response_model=RerankResponse)
async def rerank_endpoint(
    request: RerankRequest,
    api_key: str = Depends(get_api_key)  # API key authentication is required
) -> RerankResponse:
    try:
        # Generate a unique request identifier (uuid ships with the Python
        # standard library; nothing extra to install)
        request_id = str(uuid.uuid4())
        # Run the reranking logic
        results = reranker.rerank(
            query=request.query,
            documents=request.documents,
            top_k=request.top_k,
            truncation=request.truncation
        )
        # Build the response
        return RerankResponse(
            request_id=request_id,
            query=request.query,
            top_k=request.top_k,
            results=results
        )
    except ValueError as e:
        # Business-logic errors (e.g. invalid parameters)
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Internal server errors (e.g. model failures)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
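
# Example call with curl (a sketch; host, port, and key are placeholders):
#
#   curl -X POST http://localhost:7860/rerank \
#     -H "X-API-Key: my-secret-key" \
#     -H "Content-Type: application/json" \
#     -d '{"query": "what is machine learning?",
#          "documents": ["ML is a subfield of AI.", "Bananas are yellow."],
#          "top_k": 2}'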
# 6. Health-check endpoint (for monitoring the service)
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "device": reranker.device,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    }
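
# Quick liveness probe (note that this route does not require an API key):
#
#   curl http://localhost:7860/health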
# ------------------- Root path (/) landing-page route -------------------
@app.get("/", response_class=PlainTextResponse)
async def home_page():
    """Landing page: describes the API's features and authentication."""
    return """
# Cross-Encoder Reranking API

Text relevance ranking service built on the `cross-encoder/ms-marco-MiniLM-L-6-v2` model.

## Core features
- Takes a query plus a list of candidate documents and returns them in descending order of relevance (with scores and ranks)
- API key authentication keeps the endpoint secure"""
# 7. Local entry point (for development)
if __name__ == "__main__":
    import uvicorn  # install with: pip install uvicorn
    uvicorn.run(
        app="app:app",  # assumes this file is saved as app.py
        host="0.0.0.0",  # allow external access
        port=7860,  # port number
        reload=False  # disable auto-reload in production
    )
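
# To run and smoke-test locally (a sketch, assuming this file is app.py
# and "my-secret-key" is your own key):
#
#   pip install fastapi uvicorn transformers torch
#   export CROSS_ENCODER_API_KEY=my-secret-key
#   python app.py
#
# Then open http://localhost:7860/ for the landing page.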