Spaces:
Sleeping
Sleeping
import logging
import os
import uuid
from datetime import datetime
from typing import List, Optional

import torch
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# ------------------- 1. Logging configuration -------------------
# Configure the root logger once: INFO and above, with records rendered as
# "<timestamp> - <logger name> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Named logger shared by the rest of this module.
logger = logging.getLogger("cross-encoder-api")
# ------------------- 2. Base configuration (cache + environment variables) -------------------
# Point both Hugging Face cache variables at the same directory under /tmp.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/huggingface_cache",
        "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface_cache",
    }
)

# The API key is read from the environment (OpenAI-style variable name).
# A missing or empty value aborts start-up with a ValueError.
API_KEY = os.environ.get("OPENAI_API_KEY")
if API_KEY:
    logger.info("API Key 加载成功")
else:
    logger.error("环境变量 OPENAI_API_KEY 未设置")
    raise ValueError("请设置环境变量 OPENAI_API_KEY")
# ------------------- 3. FastAPI application -------------------
# Title, description and version are passed straight to the FastAPI constructor.
app = FastAPI(
    version="1.0.0",
    title="OpenAI 兼容的 Cross-Encoder 重排序 API",
    description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
)
# ------------------- 4. OpenAI-style authentication (Bearer token) -------------------
# auto_error=False lets verify_api_key produce its own 401 responses and log
# entries instead of relying on the security scheme's built-in error.
oauth2_scheme = HTTPBearer(auto_error=False)


def _unauthorized(detail: str) -> HTTPException:
    """Build the 401 error shared by every authentication failure."""
    return HTTPException(
        status_code=401,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def verify_api_key(
    credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme),
) -> tuple:
    """Validate the ``Authorization: Bearer <API key>`` header of a request.

    Args:
        credentials: Parsed Authorization header supplied by ``oauth2_scheme``;
            falsy when the dependency could not provide credentials.

    Returns:
        A ``(api_key, request_id)`` tuple. ``request_id`` is a freshly
        generated 8-character id used to correlate log lines of one request.

    Raises:
        HTTPException: 401 when credentials are missing, the scheme is not
            Bearer, or the key does not match ``API_KEY``.
    """
    # Local import keeps this block self-contained.
    import secrets

    request_id = str(uuid.uuid4())[:8]  # short id for log correlation

    if not credentials:
        logger.warning(f"请求 {request_id}:缺少认证信息")
        raise _unauthorized("缺少认证信息(请使用 'Authorization: Bearer YOUR_API_KEY')")

    # HTTP authentication scheme names are case-insensitive (RFC 7235), so
    # "bearer" / "BEARER" must be accepted as well as "Bearer".
    if credentials.scheme.lower() != "bearer":
        logger.warning(f"请求 {request_id}:认证方案错误,应为 Bearer,实际为 {credentials.scheme}")
        raise _unauthorized("认证方案错误(请使用 'Bearer' 方案)")

    # Constant-time comparison avoids leaking how many leading characters of
    # the key matched. Both sides are encoded to bytes because compare_digest
    # rejects non-ASCII str arguments.
    key_matches = secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not key_matches:
        logger.warning(f"请求 {request_id}:无效的 API Key")
        raise _unauthorized("无效的 API Key")

    logger.info(f"请求 {request_id}:API Key 验证通过")
    return (credentials.credentials, request_id)
# ------------------- 5. Data models -------------------
class RerankRequest(BaseModel):
    """Request body for the basic rerank endpoint."""

    # Query text that every candidate document is scored against.
    query: str
    # Candidate documents to be ranked.
    documents: List[str]
    # Number of results requested; defaults to 3.
    top_k: Optional[int] = 3
    # Forwarded to the tokenizer's ``truncation`` argument; defaults to True.
    truncation: Optional[bool] = True
class DocumentScore(BaseModel):
    """One scored document in a rerank result list."""

    # Text of the document.
    document: str
    # Relevance score assigned to the document for the query.
    relevance_score: float
    # Integer position associated with this entry; its exact meaning is
    # decided by the code that builds the result list.
    index: int
class RerankResponse(BaseModel):
    """Response body of the basic rerank endpoint."""

    # Short id that ties the response to the server-side log lines.
    request_id: str
    # Query the results were computed for.
    query: str
    # Number of results actually returned.
    top_k: int
    # Scored documents, as produced by the reranker.
    results: List[DocumentScore]
    # Name of the model that produced the scores.
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    # Creation time with millisecond precision. A default_factory is required
    # here: a plain default expression is evaluated once when the class is
    # defined, which would stamp every response with the start-up time.
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    )
# GPT-compatible request/response models
class GPTMessage(BaseModel):
    """A single chat message."""

    # Speaker role of the message (the endpoint code checks for "user" and emits "assistant").
    role: str
    # Message text.
    content: str
class GPTRequest(BaseModel):
    """Request body of the chat-completions style endpoint."""

    # Model name requested by the client.
    model: str
    # Conversation messages.
    messages: List[GPTMessage]
    # Number of results requested; defaults to 3.
    top_k: Optional[int] = 3
class Choice(BaseModel):
    """One completion choice in a chat-completions style response."""

    # Position of this choice in the ``choices`` list.
    index: int
    # Message carried by this choice.
    message: GPTMessage
    # Reason the completion ended; defaults to "stop".
    finish_reason: str = "stop"
class GPTResponse(BaseModel):
    """Chat-completions style response envelope.

    ``id`` and ``created`` use default factories so that each response gets
    its own value. Plain default expressions are evaluated once at class
    definition, which would give every response the same id and creation time.
    """

    # Unique completion id of the form "chatcmpl-<32 hex chars>".
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    # Object type marker.
    object: str = "chat.completion"
    # Creation time as integer Unix seconds.
    created: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    # Name of the model that produced the response.
    model: str
    # Completion choices.
    choices: List[Choice]
    # Token accounting; all counters default to zero. A factory gives each
    # response its own dict.
    usage: dict = Field(
        default_factory=lambda: {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )
# ------------------- 6. Cross-encoder model -------------------
class CrossEncoderModel:
    """Load a sequence-classification cross-encoder and rank documents with it.

    Attributes set by ``__init__``:
        model_name: Name passed to ``from_pretrained``.
        tokenizer: Tokenizer loaded for ``model_name``.
        model: Model loaded for ``model_name``, moved to ``device`` and put in eval mode.
        device: ``"cuda"`` when ``torch.cuda.is_available()``, otherwise ``"cpu"``.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Verify the cache directory is writable, then load tokenizer and model.

        Args:
            model_name: Model identifier handed to ``from_pretrained``.

        Raises:
            RuntimeError: If the cache directory cannot be created or written.
            Exception: Any error raised while loading the tokenizer or model
                is logged and re-raised unchanged.
        """
        # Local import keeps this block self-contained.
        import tempfile

        self.model_name = model_name
        logger.info(f"开始加载模型:{model_name}")

        # Make sure the cache directory exists and is writable before any download.
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
        try:
            # The directory may not exist yet on a fresh container; probing a
            # missing directory would be misreported as "not writable".
            os.makedirs(cache_dir, exist_ok=True)
            # A uniquely named, self-deleting probe file cannot collide with
            # real cache content or with another process probing concurrently.
            with tempfile.TemporaryFile(dir=cache_dir) as probe:
                probe.write(b"test")
            logger.info(f"缓存目录可写:{cache_dir}")
        except OSError as e:
            logger.error(f"缓存目录不可写:{str(e)}")
            raise RuntimeError(f"缓存目录不可写:{str(e)}") from e

        # Load tokenizer and weights, then move the model to the chosen device.
        try:
            logger.info("开始加载分词器...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            logger.info("分词器加载完成")
            logger.info("开始加载模型权重...")
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=cache_dir)
            logger.info("模型权重加载完成")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            self.model.eval()
            logger.info(f"模型加载完成,使用设备:{self.device}")
        except Exception as e:
            logger.error(f"模型加载失败:{str(e)}")
            raise

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int],
        truncation: bool,
        request_id: str,
    ) -> List[DocumentScore]:
        """Score every document against the query and return the best ones.

        Args:
            query: Query text.
            documents: Candidate documents; must not be empty.
            top_k: Maximum number of results. ``None`` means "all documents";
                values larger than ``len(documents)`` are clamped.
            truncation: Forwarded to the tokenizer's ``truncation`` argument.
            request_id: Id used to prefix log lines.

        Returns:
            Up to ``top_k`` ``DocumentScore`` items ordered by descending
            score (rounded to 4 decimals). ``index`` is the position of the
            document in the *input* ``documents`` list, so callers can map a
            result back to its source; the rank is the list position.

        Raises:
            ValueError: If ``documents`` is empty or ``top_k`` is not positive.
            Exception: Any tokenizer/model error is logged and re-raised.
        """
        logger.info(f"请求 {request_id}:开始重排序处理,查询长度: {len(query)}, 文档数量: {len(documents)}, top_k: {top_k}")

        # Argument validation.
        if not documents:
            logger.warning(f"请求 {request_id}:候选文档列表为空")
            raise ValueError("候选文档不能为空")
        if top_k is None:
            # The request schemas declare top_k as Optional; treat an explicit
            # null as "return everything" instead of failing on `None <= 0`.
            top_k = len(documents)
        if top_k <= 0:
            logger.warning(f"请求 {request_id}:无效的 top_k 值: {top_k}")
            raise ValueError("top_k 必须为正整数")

        # Clamp top_k to the number of documents.
        adjusted_top_k = min(top_k, len(documents))
        if adjusted_top_k != top_k:
            logger.info(f"请求 {request_id}:top_k 从 {top_k} 调整为 {adjusted_top_k}(文档数量限制)")

        # (original index, document, score) for every candidate.
        scored = []
        try:
            for i, doc in enumerate(documents):
                if i % 5 == 0:  # progress log every 5 documents
                    logger.info(f"请求 {request_id}:正在处理第 {i+1}/{len(documents)} 个文档")
                # Pass query and document as a text pair so the tokenizer
                # inserts the separator itself and emits per-segment token
                # type ids; hand-concatenating them into one string yields a
                # single-segment input. No padding is requested because each
                # call encodes exactly one pair.
                inputs = self.tokenizer(
                    query,
                    doc,
                    return_tensors="pt",
                    truncation=truncation,
                    max_length=512,
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.model(**inputs)
                score = outputs.logits.item()
                scored.append((i, doc, score))
                logger.debug(f"请求 {request_id}:文档 {i+1} 分数: {score:.4f}")

            # Highest score first; sorted() is stable, so ties keep input order.
            ranked = sorted(scored, key=lambda item: item[2], reverse=True)[:adjusted_top_k]
            logger.info(f"请求 {request_id}:重排序完成,返回 {len(ranked)} 个结果")
            return [
                DocumentScore(document=doc, relevance_score=round(score, 4), index=original_index)
                for original_index, doc, score in ranked
            ]
        except Exception as e:
            logger.error(f"请求 {request_id}:重排序过程出错: {str(e)}")
            raise
# Build the single, module-wide model instance at import time. A failure is
# logged at CRITICAL level and re-raised, so the module import itself fails.
try:
    reranker = CrossEncoderModel()
except Exception as init_error:
    logger.critical(f"模型初始化失败,服务无法启动: {init_error}")
    raise
# ------------------- 7. API endpoints (OpenAI-style paths) -------------------
# 7.1 Landing page served at the root path.
# The page is static, so it lives in a plain string constant; no brace
# escaping is needed for the CSS and the embedded Python example.
_HOME_PAGE_HTML = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>OpenAI 兼容重排序 API</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
        h1 { color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }
        h2 { color: #34495e; margin-top: 30px; }
        pre { background: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #e9ecef; overflow-x: auto; }
        table { border-collapse: collapse; width: 100%; margin: 20px 0; }
        th, td { border: 1px solid #e9ecef; padding: 12px; text-align: left; }
        th { background-color: #f1f5f9; }
    </style>
</head>
<body>
    <h1>OpenAI 兼容的 Cross-Encoder 重排序 API</h1>
    <p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,支持 OpenAI 风格 API 调用。</p>
    <h2>接口列表</h2>
    <table>
        <tr>
            <th>接口</th>
            <th>URL</th>
            <th>方法</th>
            <th>认证</th>
        </tr>
        <tr>
            <td>基础重排序</td>
            <td class="api-url">/v1/rerank</td>
            <td>POST</td>
            <td>Authorization: Bearer API_KEY</td>
        </tr>
        <tr>
            <td>GPT 兼容重排序</td>
            <td class="api-url">/v1/chat/completions</td>
            <td>POST</td>
            <td>Authorization: Bearer API_KEY</td>
        </tr>
        <tr>
            <td>健康检查</td>
            <td class="api-url">/v1/health</td>
            <td>GET</td>
            <td>无需认证</td>
        </tr>
    </table>
    <h2>调用示例(Python)</h2>
    <pre><code>import openai
client = openai.OpenAI(
    api_key="YOUR_API_KEY",
    base_url="https://your-space.hf.space/v1" # 替换为你的 Space URL
)
response = client.chat.completions.create(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    messages=[
        {
            "role": "user",
            "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言;"
        }
    ],
    extra_body={"top_k": 2}
)
print(response.choices[0].message.content)</code></pre>
</body>
</html>
"""


# The route must be registered on the app or the page is never served;
# response_class=HTMLResponse makes the returned string render as HTML
# instead of being JSON-encoded.
@app.get("/", response_class=HTMLResponse)
async def home_page(request: Request):
    """Serve the static HTML landing page that documents the API.

    The embedded client example sends ``top_k`` through ``extra_body``: the
    OpenAI Python client rejects unknown keyword arguments on
    ``chat.completions.create``, so a bare ``top_k=`` would fail before any
    request is sent.

    Args:
        request: Incoming request; only used to log the caller's address.

    Returns:
        The HTML document as a string.
    """
    # request.client can be None (no peer address available); avoid an
    # AttributeError just for a log line.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(f"首页访问来自 {client_ip}")
    return _HOME_PAGE_HTML
# 7.2 Basic rerank endpoint (/v1/rerank)
# Registering the route is required for the endpoint to be reachable.
@app.post("/v1/rerank", response_model=RerankResponse)
async def base_rerank(
    request: RerankRequest,
    auth_result: tuple = Depends(verify_api_key),
):
    """Rank ``request.documents`` against ``request.query``.

    Args:
        request: Parsed request body.
        auth_result: ``(api_key, request_id)`` produced by ``verify_api_key``.

    Returns:
        A ``RerankResponse`` whose ``top_k`` is the number of results returned.

    Raises:
        HTTPException: 400 when the reranker rejects the arguments
            (``ValueError``), 500 for any other failure.
    """
    _, request_id = auth_result

    # Both fields are Optional in the schema, so a client may send explicit
    # nulls. Null top_k means "return every document"; null truncation falls
    # back to the schema default (True).
    requested_top_k = request.top_k if request.top_k is not None else len(request.documents)
    truncate = True if request.truncation is None else request.truncation

    try:
        logger.info(f"请求 {request_id}:收到 /v1/rerank 请求,query: {request.query[:50]}...(截断显示)")

        # NOTE(review): rerank() is a synchronous call made from an async
        # handler, so the event loop is occupied while the model runs.
        results = reranker.rerank(
            query=request.query,
            documents=request.documents,
            top_k=requested_top_k,
            truncation=truncate,
            request_id=request_id,
        )

        # rerank() already clamps to the number of documents, so the number
        # of results is the effective top_k.
        response = RerankResponse(
            request_id=request_id,
            query=request.query,
            top_k=len(results),
            results=results,
        )
        logger.info(f"请求 {request_id}:处理完成,返回 {len(results)} 个结果")
        return response
    except ValueError as e:
        logger.warning(f"请求 {request_id}:参数错误 - {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"请求 {request_id}:服务器错误 - {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
# 7.3 GPT-compatible endpoint (/v1/chat/completions)
# Registering the route is required for the endpoint to be reachable.
@app.post("/v1/chat/completions", response_model=GPTResponse)
async def gpt_compatible_rerank(
    request: GPTRequest,
    auth_result: tuple = Depends(verify_api_key),
):
    """Rerank documents passed through a chat-completions style request.

    The last message must have role ``user`` and content of the form
    ``query: <query>; documents: <doc1>; <doc2>; ...``.

    Args:
        request: Parsed request body.
        auth_result: ``(api_key, request_id)`` produced by ``verify_api_key``.

    Returns:
        A ``GPTResponse`` with one assistant message containing the results.

    Raises:
        HTTPException: 400 for an unsupported model, a malformed message
            list/content, or arguments rejected by the reranker; 500 for any
            other failure.
    """
    _, request_id = auth_result
    separator = "; documents: "
    query_prefix = "query: "

    try:
        logger.info(f"请求 {request_id}:收到 /v1/chat/completions 请求,模型: {request.model}")

        # Only the loaded model can be served.
        if request.model != reranker.model_name:
            error_msg = f"仅支持模型:{reranker.model_name},实际请求:{request.model}"
            logger.warning(f"请求 {request_id}:{error_msg}")
            raise ValueError(error_msg)

        # Validate the message list.
        if not request.messages:
            logger.warning(f"请求 {request_id}:消息列表为空")
            raise ValueError("消息列表不能为空")
        if request.messages[-1].role != "user":
            error_msg = f"最后一条消息必须是 'user' 角色,实际为:{request.messages[-1].role}"
            logger.warning(f"请求 {request_id}:{error_msg}")
            raise ValueError(error_msg)

        # Parse the user content.
        content = request.messages[-1].content
        logger.info(f"请求 {request_id}:用户输入: {content[:100]}...(截断显示)")

        # partition() splits on the FIRST separator only, so a document that
        # itself contains "; documents: " no longer breaks tuple unpacking
        # with an unrelated "too many values" error.
        query_part, found, docs_part = content.partition(separator)
        if not found:
            error_msg = "输入格式需为 'query: [查询]; documents: [文档1]; [文档2]; ...'"
            logger.warning(f"请求 {request_id}:{error_msg}")
            raise ValueError(error_msg)

        # Strip the marker only where it is a prefix; a global replace would
        # also delete "query: " occurring inside the query text itself.
        query = query_part.strip()
        if query.startswith(query_prefix):
            query = query[len(query_prefix):].strip()
        documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
        logger.info(f"请求 {request_id}:解析完成,query: {query[:50]}..., 文档数量: {len(documents)}")

        # top_k is Optional in the schema; an explicit null means "all documents".
        requested_top_k = request.top_k if request.top_k is not None else len(documents)

        results = reranker.rerank(
            query=query,
            documents=documents,
            top_k=requested_top_k,
            truncation=True,
            request_id=request_id,
        )

        # Build the GPT-style response; the result list is rendered with its
        # Python repr inside the assistant message.
        response = GPTResponse(
            model=request.model,
            choices=[
                Choice(
                    index=0,
                    message=GPTMessage(
                        role="assistant",
                        content=f"重排序结果:{results}",
                    ),
                )
            ],
        )
        logger.info(f"请求 {request_id}:处理完成,返回 {len(results)} 个结果")
        return response
    except ValueError as e:
        logger.warning(f"请求 {request_id}:参数错误 - {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"请求 {request_id}:服务器错误 - {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
# 7.4 Health-check endpoint (/v1/health)
# Recorded when this statement runs at import time; used as the reference
# point for the reported uptime.
_SERVICE_STARTED_AT = datetime.now()


# Registering the route is required for the endpoint to be reachable.
@app.get("/v1/health")
async def health_check(request: Request):
    """Report service status, model name, device, current time and uptime.

    Args:
        request: Incoming request; only used to log the caller's address.

    Returns:
        A dict with keys ``status``, ``model``, ``device``, ``timestamp``
        (``YYYY-MM-DD HH:MM:SS``) and ``uptime`` (elapsed time since
        ``_SERVICE_STARTED_AT``, formatted like ``0:05:12`` or
        ``1 day, 2:03:04``).
    """
    # request.client can be None (no peer address available).
    client_ip = request.client.host if request.client else "unknown"
    now = datetime.now()
    # str(timedelta) yields "H:MM:SS[.ffffff]"; drop the fractional part.
    # Reporting the current time again under "uptime" would carry no
    # information beyond "timestamp".
    elapsed = str(now - _SERVICE_STARTED_AT).split(".")[0]
    status = {
        "status": "healthy",
        "model": reranker.model_name,
        "device": reranker.device,
        "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
        "uptime": elapsed,
    }
    logger.info(f"健康检查来自 {client_ip}:{status['status']}")
    return status
# ------------------- 8. Local run entry point -------------------
if __name__ == "__main__":
    import uvicorn

    logger.info("启动本地开发服务器...")
    # log_config=None is passed so the logging setup made by this module
    # stays in effect instead of uvicorn's own logging configuration.
    server_options = {"host": "0.0.0.0", "port": 7860, "log_config": None}
    uvicorn.run(app, **server_options)