Spaces:
Sleeping
Sleeping
import json
import os
import secrets
import uuid
from datetime import datetime
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ------------------- 1. Basic configuration (cache directory + environment variables) -------------------
# Point both Hugging Face cache variables at a writable location under /tmp.
# NOTE(review): transformers is imported above this point, and such libraries
# typically read these variables at import time, so these assignments may not
# change library defaults; the model loader reads TRANSFORMERS_CACHE itself and
# passes it as cache_dir explicitly. Confirm before relying on the env vars alone.
os.environ.update(
    TRANSFORMERS_CACHE="/tmp/huggingface_cache",
    HUGGINGFACE_HUB_CACHE="/tmp/huggingface_cache",
)

# Key that clients must present as a Bearer token (OpenAI-style naming).
API_KEY = os.environ.get("OPENAI_API_KEY")
if API_KEY is None or API_KEY == "":
    raise ValueError("请设置环境变量 OPENAI_API_KEY")
# ------------------- 2. Initialize the FastAPI application -------------------
# Title/description/version are surfaced in the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="OpenAI 兼容的 Cross-Encoder 重排序 API",
    description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
)
# ------------------- 3. OpenAI-style authentication (Bearer token) -------------------
# auto_error=False: a missing/malformed header yields None instead of FastAPI's
# default 403, so verify_api_key below can answer with a uniform 401.
oauth2_scheme = HTTPBearer(auto_error=False)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme)):
    """Validate the API key sent as ``Authorization: Bearer YOUR_API_KEY``.

    Returns:
        The presented key when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            header is missing, uses another scheme, or carries the wrong key.
    """
    is_valid = (
        credentials is not None
        # Auth schemes are case-insensitive; HTTPBearer passes the scheme through
        # exactly as the client sent it (e.g. "bearer").
        and credentials.scheme.lower() == "bearer"
        # Constant-time comparison so response timing does not leak the key.
        # Compare bytes: compare_digest rejects non-ASCII str arguments.
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"), API_KEY.encode("utf-8")
        )
    )
    if not is_valid:
        raise HTTPException(
            status_code=401,
            detail="无效的 API Key(请使用 'Authorization: Bearer YOUR_API_KEY')",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
# ------------------- 4. Data model definitions -------------------
class RerankRequest(BaseModel):
    """Request body for the basic rerank endpoint."""

    query: str  # search query each document is scored against
    documents: List[str]  # candidate documents to rerank
    top_k: Optional[int] = 3  # how many top-scoring documents to return
    truncation: Optional[bool] = True  # forwarded to the tokenizer's truncation flag
class DocumentScore(BaseModel):
    """One reranked document together with its model score and rank."""

    document: str
    score: float  # cross-encoder logit, rounded to 4 decimals by the reranker
    rank: int  # 1-based position after sorting by descending score
class RerankResponse(BaseModel):
    """Response body of the basic rerank endpoint."""

    request_id: str
    query: str
    top_k: int
    results: List[DocumentScore]
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    # default_factory runs per instance. A plain default expression would be
    # evaluated once at import time, stamping every response with the same time.
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    )
# GPT-compatible request/response models
class GPTMessage(BaseModel):
    """A single chat message in the OpenAI ``messages`` format."""

    role: str  # e.g. "user" or "assistant"
    content: str
class GPTRequest(BaseModel):
    """Chat-completions-style request body for the GPT-compatible endpoint."""

    model: str  # must match the loaded reranker's model name
    messages: List[GPTMessage]  # the last message carries the query/documents prompt
    top_k: Optional[int] = 3  # extra (non-OpenAI) field: number of results to return
class Choice(BaseModel):
    """One entry of a chat-completion ``choices`` list."""

    index: int
    message: GPTMessage
    finish_reason: str = "stop"
class GPTResponse(BaseModel):
    """OpenAI ``chat.completion``-shaped response wrapping the rerank result."""

    # default_factory is evaluated per response. Plain defaults would be computed
    # once at import time, giving every completion the same id and timestamp.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    model: str
    choices: List[Choice]
    # Token accounting is not tracked; report zeros like the original default.
    usage: dict = Field(
        default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    )
# ------------------- 5. Load the Cross-Encoder model -------------------
class CrossEncoderModel:
    """Wrapper around a Hugging Face sequence-classification cross-encoder.

    Loads the tokenizer and model once, moves the model to CUDA when available,
    and exposes :meth:`rerank` to score (query, document) pairs.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model_name = model_name
        # Verify the cache directory is writable before downloading into it.
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
        try:
            # Create the directory first: on a fresh container it does not exist
            # yet, and the write probe below would otherwise fail spuriously.
            os.makedirs(cache_dir, exist_ok=True)
            test_file = os.path.join(cache_dir, "test.txt")
            with open(test_file, "w") as f:
                f.write("test")
            os.remove(test_file)
            print(f"缓存目录可写:{cache_dir}")
        except OSError as e:
            raise RuntimeError(f"缓存目录不可写:{str(e)}") from e

        # Load tokenizer + model into the verified cache directory.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()
        print(f"模型加载完成,设备:{self.device}")

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_k: Optional[int],
        truncation: Optional[bool],
        batch_size: int = 32,
    ) -> List[DocumentScore]:
        """Score every document against ``query`` and return the best ``top_k``.

        Args:
            query: The search query.
            documents: Candidate documents; must be non-empty.
            top_k: Number of results to return. ``None`` returns every document.
                Values larger than ``len(documents)`` are clamped to it.
            truncation: Whether to truncate each (query, document) pair to 512
                tokens. ``None`` is treated as ``True``.
            batch_size: Number of pairs scored per forward pass.

        Returns:
            ``DocumentScore`` objects sorted by descending score, ranked from 1.

        Raises:
            ValueError: If ``documents`` is empty, ``top_k`` is not positive, or
                ``batch_size`` is not positive.
        """
        if not documents:
            raise ValueError("候选文档不能为空")
        if top_k is None:
            top_k = len(documents)
        if top_k <= 0:
            raise ValueError(f"top_k 需在 1~{len(documents)} 之间")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        # Clamp rather than reject, so a request with fewer documents than the
        # request models' default top_k (3) still succeeds.
        top_k = min(top_k, len(documents))
        truncation = truncation is not False

        scores: List[float] = []
        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            # Encode as (query, document) sentence pairs instead of a hand-built
            # "query [SEP] doc" string. The tokenizer then inserts the special
            # tokens and segment ids itself. padding=True pads only to the longest
            # pair in the batch, not to a fixed 512.
            inputs = self.tokenizer(
                [query] * len(batch),
                batch,
                return_tensors="pt",
                padding=True,
                truncation=truncation,
                max_length=512,
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # Assumes one relevance logit per pair (the previous per-pair .item()
            # relied on the same thing).
            scores.extend(logits.squeeze(-1).tolist())

        # Stable sort: equal scores keep their input order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)[:top_k]
        return [
            DocumentScore(document=doc, score=round(score, 4), rank=i + 1)
            for i, (doc, score) in enumerate(ranked)
        ]
# Module-level model instance, loaded once at import time.
reranker = CrossEncoderModel(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
# ------------------- 6. API endpoints (OpenAI-style paths) -------------------
# 6.1 Root-path home page
# Static page, so it is built once as a plain string (no f-string escaping needed).
# The example passes top_k via extra_body: the openai client's create() rejects
# keyword arguments that are not part of the OpenAI API.
_HOME_PAGE_HTML = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>OpenAI 兼容重排序 API</title>
<style>
body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
h1 { color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }
h2 { color: #34495e; margin-top: 30px; }
pre { background: #f8f9fa; padding: 15px; border-radius: 5px; border: 1px solid #e9ecef; overflow-x: auto; }
table { border-collapse: collapse; width: 100%; margin: 20px 0; }
th, td { border: 1px solid #e9ecef; padding: 12px; text-align: left; }
th { background-color: #f1f5f9; }
</style>
</head>
<body>
<h1>OpenAI 兼容的 Cross-Encoder 重排序 API</h1>
<p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,支持 OpenAI 风格 API 调用。</p>
<h2>接口列表</h2>
<table>
<tr>
<th>接口</th>
<th>URL</th>
<th>方法</th>
<th>认证</th>
</tr>
<tr>
<td>基础重排序</td>
<td class="api-url">/v1/rerank</td>
<td>POST</td>
<td>Authorization: Bearer API_KEY</td>
</tr>
<tr>
<td>GPT 兼容重排序</td>
<td class="api-url">/v1/chat/completions</td>
<td>POST</td>
<td>Authorization: Bearer API_KEY</td>
</tr>
<tr>
<td>健康检查</td>
<td class="api-url">/v1/health</td>
<td>GET</td>
<td>无需认证</td>
</tr>
</table>
<h2>调用示例(Python)</h2>
<pre><code>import openai

client = openai.OpenAI(
    api_key="YOUR_API_KEY",
    base_url="https://your-space.hf.space/v1"  # 替换为你的 Space URL
)

response = client.chat.completions.create(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    messages=[
        {
            "role": "user",
            "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言;"
        }
    ],
    extra_body={"top_k": 2}
)
print(response.choices[0].message.content)</code></pre>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def home_page():
    """Serve the static HTML landing page that documents the API."""
    return _HOME_PAGE_HTML
# 6.2 Basic rerank endpoint (/v1/rerank)
@app.post("/v1/rerank", response_model=RerankResponse)
async def base_rerank(
    request: RerankRequest,
    api_key: str = Depends(verify_api_key),
):
    """Rerank ``request.documents`` against ``request.query``.

    Requires a valid Bearer API key (enforced by ``verify_api_key``).

    Raises:
        HTTPException: 400 for invalid input (reranker ``ValueError``), 500 for
            any other failure during scoring.
    """
    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop while scoring; acceptable for a single-worker demo.
    try:
        results = reranker.rerank(
            query=request.query,
            documents=request.documents,
            top_k=request.top_k,
            truncation=request.truncation,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e

    return RerankResponse(
        request_id=str(uuid.uuid4()),
        query=request.query,
        # Report how many results were actually returned; this also stays a
        # valid int when the client sent top_k = null.
        top_k=len(results),
        results=results,
    )
# 6.3 GPT-compatible endpoint (/v1/chat/completions)
def _parse_rerank_prompt(content: str) -> tuple:
    """Split ``'query: <q>; documents: <d1>; <d2>; ...'`` into ``(query, documents)``.

    Raises:
        ValueError: If the ``'; documents: '`` separator is missing.
    """
    # partition splits on the first separator only, so a repeated separator
    # cannot cause an unpacking error.
    query_part, separator, docs_part = content.partition("; documents: ")
    if not separator:
        raise ValueError("输入格式需为 'query: [查询]; documents: [文档1]; [文档2]; ...'")
    query = query_part.strip()
    # Strip only the leading "query:" label. str.replace would also delete the
    # phrase wherever it appeared inside the query text.
    if query.startswith("query:"):
        query = query[len("query:"):].strip()
    documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
    return query, documents


@app.post("/v1/chat/completions", response_model=GPTResponse)
async def gpt_compatible_rerank(
    request: GPTRequest,
    api_key: str = Depends(verify_api_key),
):
    """Rerank documents embedded in the last user message, chat-completions style.

    The last message must have role ``user`` and content of the form
    ``'query: <q>; documents: <d1>; <d2>; ...'``. The assistant reply contains
    the results as a JSON array of ``{rank, document, score}`` objects.

    Raises:
        HTTPException: 400 for an unsupported model or malformed prompt, 500 for
            any other failure during scoring.
    """
    try:
        if request.model != reranker.model_name:
            raise ValueError(f"仅支持模型:{reranker.model_name}")
        if not request.messages or request.messages[-1].role != "user":
            raise ValueError("最后一条消息必须是 'user' 角色")
        query, documents = _parse_rerank_prompt(request.messages[-1].content)
        results = reranker.rerank(
            query=query,
            documents=documents,
            top_k=request.top_k,
            truncation=True,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e

    # Machine-readable JSON instead of a Python repr of pydantic objects.
    payload = [
        {"rank": item.rank, "document": item.document, "score": item.score}
        for item in results
    ]
    return GPTResponse(
        model=request.model,
        choices=[
            Choice(
                index=0,
                message=GPTMessage(
                    role="assistant",
                    content=f"重排序结果:{json.dumps(payload, ensure_ascii=False)}",
                ),
            )
        ],
    )
# 6.4 Health-check endpoint (/v1/health)
@app.get("/v1/health")
async def health_check():
    """Report service status, loaded model name, device and current local time.

    No authentication is required.
    """
    return {
        "status": "healthy",
        "model": reranker.model_name,
        "device": reranker.device,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
# ------------------- 7. Local entry point -------------------
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when the file is run directly.
    import uvicorn

    # Bind all interfaces on port 7860.
    uvicorn.run(app=app, host="0.0.0.0", port=7860)