# Hugging Face Spaces status banner captured with this paste: "Runtime error" (shown twice on the Space page).
import os
import secrets
import uuid
from datetime import datetime
from typing import List, Optional

import torch
from fastapi import Depends, FastAPI, HTTPException
from fastapi.responses import HTMLResponse  # keep only HTMLResponse; there is no MarkdownResponse
from fastapi.security import APIKeyHeader, APIKeyQuery
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ------------------- 1. Base configuration (cache dirs + environment) -------------------
# Point the Hugging Face caches at a writable directory; the default cache
# location is not writable on Spaces and model download would fail.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
# API key used to authenticate requests; must be configured as an environment
# variable / secret in the Hugging Face Space settings.
API_KEY = os.getenv("CROSS_ENCODER_API_KEY")
if not API_KEY:
    # Fail fast at import time: without a key, every request would be rejected.
    raise ValueError("请在 Hugging Face Spaces 中设置环境变量 CROSS_ENCODER_API_KEY")
# ------------------- 2. FastAPI application (initialised exactly once) -------------------
app = FastAPI(
    title="Cross-Encoder 重排序 API",
    description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口(兼容 GPT 格式)",
    version="1.0.0"
)
# ------------------- 3. API-key authentication (Header or Query) -------------------
# The key may arrive via HTTP header (preferred) or URL query parameter.
# auto_error=False makes a missing credential yield None instead of an
# immediate 403, so verify_api_key() can return one unified 401 response.
api_key_header = APIKeyHeader(
    name="X-API-Key",
    auto_error=False,
    description="通过 Header 传递 API Key(推荐)"
)
api_key_query = APIKeyQuery(
    name="api_key",
    auto_error=False,
    description="通过 URL 参数传递 API Key(如 ?api_key=xxx)"
)
def verify_api_key(
    header_key: Optional[str] = Depends(api_key_header),
    query_key: Optional[str] = Depends(api_key_query)
) -> str:
    """Validate the API key, preferring the Header value over the Query value.

    FIX: uses secrets.compare_digest instead of ``==`` so the comparison runs
    in constant time and does not leak key prefixes through response timing.

    Returns:
        The credential that matched API_KEY.

    Raises:
        HTTPException: 401 when neither credential is present and correct.
    """
    for candidate in (header_key, query_key):
        # compare_digest requires str/bytes, so skip missing credentials.
        if candidate is not None and secrets.compare_digest(candidate, API_KEY):
            return candidate
    raise HTTPException(
        status_code=401,
        detail="无效或缺失 API Key(支持:Header: X-API-Key 或 Query: ?api_key=xxx)",
        headers={"WWW-Authenticate": "X-API-Key"}
    )
# ------------------- 4. Data models (request / response schemas) -------------------
class RerankRequest(BaseModel):
    """Rerank request body (basic endpoint, also shared by the GPT adapter)."""
    query: str                          # user query, e.g. "什么是机器学习?"
    documents: List[str]                # candidate documents to be ranked
    top_k: Optional[int] = 3            # number of top documents to return
    truncation: Optional[bool] = True   # truncate inputs beyond the 512-token model limit
class DocumentScore(BaseModel):
    """Ranked result for a single document."""
    document: str   # document text
    score: float    # relevance score (higher = more relevant)
    rank: int       # 1-based rank (1 = most relevant)
class RerankResponse(BaseModel):
    """Standardised rerank response."""
    request_id: str                 # unique id per request, for troubleshooting
    query: str                      # echo of the user's query
    top_k: int                      # echo of the requested top-k
    results: List[DocumentScore]    # ranked results, best first
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # model used for scoring
    # BUG FIX: the original default was `datetime.now()...` evaluated ONCE at
    # class-definition time, so every response carried the same timestamp.
    # default_factory re-evaluates the expression for each instance.
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    )
# GPT-compatible request models (adapter for the /v1/chat/completions shape)
class GPTMessage(BaseModel):
    """One chat message; only the "user" role is accepted by the endpoint."""
    role: str     # must be "user" on the last message
    content: str  # format: "query: [查询]; documents: [文档1]; [文档2]; ..."
class GPTRequest(BaseModel):
    """GPT-style request wrapper (mimics an OpenAI chat-completion call)."""
    model: str                  # must equal the served model name (checked by the endpoint)
    messages: List[GPTMessage]  # GPT-style message list
    top_k: Optional[int] = 3    # same meaning as RerankRequest.top_k
class GPTResponse(BaseModel):
    """GPT-compatible response (mimics the OpenAI chat-completion shape).

    BUG FIX: `id` and `created` were plain defaults evaluated ONCE at import
    time, so every response shared a single id and timestamp; `choices` and
    `usage` were shared mutable defaults.  default_factory produces a fresh
    value per instance.
    """
    id: str = Field(default_factory=lambda: f"rerank-{uuid.uuid4().hex[:10]}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    choices: List[dict] = Field(default_factory=list)  # ranked-results payload
    usage: dict = Field(
        default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    )
# ------------------- 5. Cross-Encoder model (single global instance) -------------------
class CrossEncoderModel:
    """Wraps the tokenizer + model pair and exposes a rerank() method."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model_name = model_name
        # Load tokenizer and model (cached under /tmp, see section 1, to avoid
        # permission problems on Spaces).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Prefer GPU when available, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout
        print(f"模型加载完成!使用设备:{self.device}")

    def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
        """Score each (query, document) pair and return the top_k documents.

        Args:
            query: the search query.
            documents: non-empty list of candidate documents.
            top_k: number of results to return, 1..len(documents).
            truncation: truncate pairs longer than the 512-token model limit.

        Raises:
            ValueError: on an empty document list or out-of-range top_k.
        """
        if not documents:
            raise ValueError("候选文档列表不能为空")
        if top_k <= 0 or top_k > len(documents):
            raise ValueError(f"top_k 需在 1~{len(documents)} 之间")
        # BUG FIX: the original concatenated "query [SEP] doc" into ONE string,
        # which drops the segment (token_type) ids the cross-encoder was
        # trained with.  Encoding (query, doc) as a sentence PAIR — and
        # batching all pairs into a single forward pass instead of one pass
        # per document — is the canonical cross-encoder usage.  padding=True
        # pads to the longest pair in the batch rather than always to 512.
        inputs = self.tokenizer(
            [query] * len(documents),
            documents,
            return_tensors="pt",
            padding=True,
            truncation=truncation,
            max_length=512  # hard input limit of MiniLM-L-6-v2
        ).to(self.device)
        with torch.no_grad():  # inference only: skip gradient bookkeeping
            logits = self.model(**inputs).logits
        # One relevance logit per pair; the raw value is the score (no softmax).
        scores = logits.squeeze(-1).tolist()
        # Sort by score descending and keep the top_k entries.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)[:top_k]
        return [
            DocumentScore(
                document=doc,
                score=round(score, 4),  # 4 decimals for readability
                rank=position           # ranks start at 1
            )
            for position, (doc, score) in enumerate(ranked, start=1)
        ]
# Instantiate the model once at import time so every request reuses it
# (avoids reloading weights per request).
reranker = CrossEncoderModel()
# ------------------- 6. API endpoints -------------------
# 6.1 Landing page at the root path (plain HTML, no Markdown dependency).
# FIX: the route decorator was missing, so the page was never registered;
# HTMLResponse (imported at the top of the file) exists exactly for this
# handler, and the page itself documents GET / as the entry point.
@app.get("/", response_class=HTMLResponse)
async def home_page():
    """Render a static HTML help page describing the API endpoints."""
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # CSS braces are doubled ({{ }}) because this is an f-string.
    return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>Cross-Encoder 重排序 API</title>
<style>
body {{ font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }}
h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
h2 {{ color: #34495e; margin-top: 30px; }}
pre {{ background: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #e9ecef; overflow-x: auto; }}
table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
th, td {{ border: 1px solid #e9ecef; padding: 12px; text-align: left; }}
th {{ background-color: #f1f5f9; }}
.note {{ color: #6c757d; font-size: 0.9em; }}
.api-url {{ color: #3498db; font-weight: bold; }}
</style>
</head>
<body>
<h1>Cross-Encoder 重排序 API</h1>
<p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,提供文本相关性排序服务,支持 GPT 标准 API 调用格式。</p>
<h2>核心功能</h2>
<ul>
<li>输入「查询语句 + 候选文档列表」,返回按相关性降序排列的结果(含分数、排名)</li>
<li>支持两种 API 格式:基础重排序接口(/api/v1/rerank)和 GPT 兼容接口(/v1/chat/completions)</li>
<li>API Key 认证,保障接口安全</li>
</ul>
<h2>接口列表</h2>
<table>
<tr>
<th>接口名称</th>
<th>URL</th>
<th>方法</th>
<th>说明</th>
</tr>
<tr>
<td>基础重排序接口</td>
<td class="api-url">{app.root_path}/api/v1/rerank</td>
<td>POST</td>
<td>标准化重排序接口,返回结构化结果</td>
</tr>
<tr>
<td>GPT 兼容接口</td>
<td class="api-url">{app.root_path}/v1/chat/completions</td>
<td>POST</td>
<td>模仿 OpenAI 格式,可直接用 OpenAI 库调用</td>
</tr>
<tr>
<td>健康检查</td>
<td class="api-url">{app.root_path}/api/v1/health</td>
<td>GET</td>
<td>无需认证,检查服务状态</td>
</tr>
</table>
<h2>调用示例(GPT 兼容接口)</h2>
<pre><code>from openai import OpenAI
# 配置客户端(指向你的 Space 地址)
client = OpenAI(
    api_key="your-api-key-here",  # 替换为你的 API Key
    base_url="https://&lt;your-username&gt;-&lt;your-space-name&gt;.hf.space/v1"  # 替换为你的 Space URL
)
# 发送重排序请求
response = client.chat.completions.create(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",  # 固定模型名
    messages=[
        {{
            "role": "user",
            "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言; 深度学习是机器学习的子集;"
        }}
    ],
    top_k=2  # 返回 Top 2 高相关文档
)
# 打印结果
print(response.choices[0].message.content)</code></pre>
<h2>API Key 认证方式</h2>
<p>所有 POST 接口需通过以下方式之一传递 API Key:</p>
<ul>
<li><strong>Header 方式(推荐)</strong>:在请求 Header 中添加 <code>X-API-Key: your-api-key</code></li>
<li><strong>Query 方式(备用)</strong>:在 URL 后添加 <code>?api_key=your-api-key</code></li>
</ul>
<p class="note">页面生成时间: {current_time} | 模型运行设备: {reranker.device}</p>
</body>
</html>
"""
# 6.2 Basic rerank endpoint (standardised request/response format).
# FIX: the route decorator was missing, so the endpoint was never registered
# even though the landing page advertises it as POST /api/v1/rerank.
@app.post("/api/v1/rerank", response_model=RerankResponse)
async def base_rerank(
    request: RerankRequest,
    api_key: str = Depends(verify_api_key)
):
    """Rerank request.documents against request.query and return the top_k."""
    try:
        # Run the actual scoring/sorting.
        results = reranker.rerank(
            query=request.query,
            documents=request.documents,
            top_k=request.top_k,
            truncation=request.truncation
        )
        # Wrap results in the standard response envelope.
        return RerankResponse(
            request_id=str(uuid.uuid4()),
            query=request.query,
            top_k=request.top_k,
            results=results
        )
    except ValueError as e:
        # Parameter problems (empty documents, out-of-range top_k) -> 400.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Anything else (tokenizer/model failure) -> generic 500.
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
# 6.3 GPT-compatible endpoint (mimics the OpenAI chat-completions API).
# FIX: the route decorator was missing, so the endpoint was never registered
# even though the landing page advertises it as POST /v1/chat/completions.
@app.post("/v1/chat/completions", response_model=GPTResponse)
async def gpt_compatible_rerank(
    request: GPTRequest,
    api_key: str = Depends(verify_api_key)
):
    """Parse a GPT-style message into (query, documents) and rerank them."""
    try:
        # Only the model this service actually serves is accepted.
        if request.model != reranker.model_name:
            raise ValueError(f"仅支持模型:{reranker.model_name}")
        # The last message must come from the user.
        if not request.messages or request.messages[-1].role != "user":
            raise ValueError("最后一条消息必须是 'user' 角色")
        # Extract query and documents from the message content.
        content = request.messages[-1].content
        if "; documents: " not in content:
            raise ValueError("输入格式错误,需为:'query: [查询]; documents: [文档1]; [文档2]; ...'")
        # BUG FIX: maxsplit=1 — the original split on EVERY occurrence of the
        # marker, so a document containing "; documents: " crashed the
        # two-value unpacking with "too many values to unpack".
        query_part, docs_part = content.split("; documents: ", 1)
        # BUG FIX: strip only the leading "query: " marker; the original
        # .replace() also deleted the marker wherever it appeared inside the
        # query text itself.
        if query_part.startswith("query: "):
            query_part = query_part[len("query: "):]
        query = query_part.strip()
        documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
        # Run the actual scoring/sorting.
        results = reranker.rerank(
            query=query,
            documents=documents,
            top_k=request.top_k,
            truncation=True
        )
        # Format the ranked list as an assistant message, OpenAI-style.
        return GPTResponse(
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": f"重排序结果(按相关性降序):\n{[{'文档': r.document, '分数': r.score, '排名': r.rank} for r in results]}"
                },
                "finish_reason": "stop"
            }]
        )
    except ValueError as e:
        # Format/parameter problems -> 400.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Anything else -> generic 500.
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
# 6.4 Health-check endpoint (no authentication, as advertised on the page).
# FIX: the route decorator was missing, so the endpoint was never registered
# even though the landing page advertises it as GET /api/v1/health.
@app.get("/api/v1/health")
async def health_check():
    """Report service status, served model name and compute device."""
    return {
        "status": "healthy",
        "model": reranker.model_name,
        "device": reranker.device,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "message": "服务正常运行"
    }
# ------------------- 7. Local development entry point -------------------
if __name__ == "__main__":
    import uvicorn
    # NOTE(review): "app:app" assumes this file is named app.py (the usual
    # convention on Hugging Face Spaces) — confirm against the actual filename.
    uvicorn.run(
        app="app:app",
        host="0.0.0.0",
        port=7860,      # the port Hugging Face Spaces exposes
        reload=False    # keep auto-reload off outside development
    )