Spaces:
Sleeping
Sleeping
import asyncio
import base64
import inspect
import time
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# Create the FastAPI application that hosts the OpenAI-compatible API.
app = FastAPI()

# CORS: let browser clients on any origin call the API.
# allow_credentials is False on purpose. The service uses no cookies, and
# Starlette documents that credentials cannot be combined with wildcard
# ("*") origins/methods/headers. With that combination Starlette instead
# reflects the caller's Origin, so every site could make credentialed
# cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the BGE-M3 tokenizer and encoder once, at import time; all requests
# share these module-level objects.
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so it can be chained here.
# Eval mode switches off training-only behaviour such as dropout.
model = AutoModel.from_pretrained(model_name).eval()
# OpenAI-compatible request body
class EmbeddingRequest(BaseModel):
    """Request body for the embeddings endpoint (subset of the OpenAI schema)."""

    # A single text, or a list of texts, to embed.
    input: list[str] | str
    # Model name echoed back in the response; defaults to the loaded model.
    model: str | None = model_name
    # Requested encoding of the returned vectors.
    encoding_format: str | None = "float"
    # Caller-supplied end-user identifier; nothing in this file reads it.
    user: str | None = None
# OpenAI-compatible response body
class EmbeddingResponse(BaseModel):
    """Response body for the embeddings endpoint (OpenAI "list" envelope)."""

    # Envelope type tag; always "list" unless overridden.
    object: str = "list"
    # One dict per embedded input (object / embedding / index entries).
    data: list[dict[str, Any]]
    # Name of the model reported to the caller.
    model: str
    # Token accounting, e.g. prompt_tokens / total_tokens.
    usage: dict[str, int]
def get_embedding(text: str, max_length: int = 512) -> List[float]:
    """Return the dense embedding of *text* as a list of Python floats.

    The vector is the final hidden state of the first token (CLS pooling)
    produced by the module-level ``model``. It is not L2-normalised here.

    Args:
        text: The string to embed.
        max_length: Maximum number of tokens fed to the model. Longer inputs
            are truncated. Defaults to 512, the previously hard-coded value.
            The upper bound the model accepts is not checked here.

    Returns:
        The embedding vector as a ``list[float]``.
    """
    # A single string needs no padding, so only truncation is requested.
    inputs = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # [batch 0, token position 0] is the CLS vector. .cpu() is a no-op on CPU
    # but keeps .numpy() working if the model is ever moved to a GPU.
    return outputs.last_hidden_state[0, 0].cpu().numpy().tolist()
# OpenAI-compatible embeddings endpoint
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
def create_embeddings(request: EmbeddingRequest):
    """Handle an OpenAI-style embeddings request.

    Registered on ``app`` at ``POST /v1/embeddings``, the path OpenAI clients
    call. It is declared as a plain ``def`` rather than ``async def``: the
    model inference below is blocking, CPU-bound work, and FastAPI runs sync
    handlers in a worker thread instead of on the event loop. The plain
    ``def`` also lets in-process callers use the return value directly.

    Args:
        request: Parsed request body. ``input`` may be a string or a list.

    Returns:
        An ``EmbeddingResponse`` with one entry per input text, in input
        order. Each embedding is a float list, or a base64 string when
        ``encoding_format == "base64"``. Usage is counted with the model's
        tokenizer.

    Raises:
        HTTPException: 400 if ``input`` is an empty list, or if
            ``encoding_format`` is neither ``"float"`` nor ``"base64"``.
    """
    # Normalise to a list so single strings and batches share one code path.
    if isinstance(request.input, str):
        input_texts = [request.input]
    else:
        input_texts = list(request.input)
    if not input_texts:
        raise HTTPException(status_code=400, detail="'input' must not be empty")

    encoding_format = request.encoding_format or "float"
    if encoding_format not in ("float", "base64"):
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unsupported encoding_format {encoding_format!r}; "
                "expected 'float' or 'base64'"
            ),
        )

    data = []
    total_tokens = 0
    for index, text in enumerate(input_texts):
        # Usage reflects the full, untruncated token count, including the
        # special tokens tokenizer.encode adds by default.
        total_tokens += len(tokenizer.encode(text))
        embedding = get_embedding(text)
        if encoding_format == "base64":
            # OpenAI's base64 form: little-endian float32 bytes, base64-encoded.
            embedding = base64.b64encode(
                np.asarray(embedding, dtype="<f4").tobytes()
            ).decode("ascii")
        data.append({"object": "embedding", "embedding": embedding, "index": index})

    return EmbeddingResponse(
        data=data,
        model=request.model or model_name,
        usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
    )
# Gradio UI callback
def gradio_embedding(text: str) -> Dict:
    """Embed *text* through the OpenAI-compatible handler for the Gradio UI.

    Wraps *text* in the same ``EmbeddingRequest`` an API client would send and
    passes it to ``create_embeddings``. The UI therefore shows exactly the
    API's response payload.

    Args:
        text: The text entered in the Gradio textbox.

    Returns:
        The response model converted to a plain dict for JSON display.
    """
    request = EmbeddingRequest(input=text)
    response = create_embeddings(request)
    # If create_embeddings is declared ``async def``, calling it only builds a
    # coroutine; it must be run to completion before a response object exists.
    # This assumes no event loop is already running in the current thread
    # (asyncio.run raises if one is). Presumably Gradio calls sync callbacks
    # from a worker thread; confirm for the Gradio version in use.
    if inspect.iscoroutine(response):
        response = asyncio.run(response)
    # Pydantic v2 renamed .dict() to .model_dump(); support both major versions.
    to_dict = getattr(response, "model_dump", None) or response.dict
    return to_dict()
# Build the Gradio interface around gradio_embedding; each example is a
# single-text input row.
_EXAMPLE_TEXTS = ["这是一个示例文本。", "人工智能正在改变世界。"]
demo = gr.Interface(
    fn=gradio_embedding,
    inputs=gr.Textbox(lines=3, placeholder="输入要进行编码的文本..."),
    outputs=gr.JSON(),
    title="BGE-M3 Embeddings (OpenAI 兼容格式)",
    description="输入文本,获取其对应的嵌入向量,返回格式与 OpenAI API 兼容。",
    examples=[[sample] for sample in _EXAMPLE_TEXTS],
)
# Start the service: a single uvicorn process serves both the API and the UI.
if __name__ == "__main__":
    import uvicorn

    # Enable Gradio's request queue for the demo.
    demo.queue()
    # Mount the Gradio UI on the same FastAPI app that uvicorn serves.
    # Without this, the demo is never exposed. Routes registered on ``app``
    # earlier in the file (such as the API endpoints) are matched before this
    # catch-all mount at "/".
    app = gr.mount_gradio_app(app, demo, path="/")

    config = uvicorn.Config(
        app=app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )
    server = uvicorn.Server(config)
    server.run()