fiewolf1000 commited on
Commit
c6533bc
·
verified ·
1 Parent(s): d7a364f

Upload 3 files

Browse files
Files changed (3) hide show
  1. Procfile +1 -0
  2. app.py +101 -49
  3. requirements.txt +2 -1
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: gunicorn app:app
app.py CHANGED
@@ -1,56 +1,108 @@
 
1
  from sentence_transformers import SentenceTransformer
2
- import gradio as gr
3
  import numpy as np
 
 
4
 
5
- # 加载模型(首次运行会自动下载,约 500MB)
6
- model = SentenceTransformer('BAAI/bge-small-en-v1.5')
7
-
8
- def get_embedding(text: str) -> list:
9
- """生成文本嵌入向量"""
10
- if not text.strip():
11
- return "请输入非空文本"
12
- # 生成嵌入(返回 numpy 数组)
13
- embedding = model.encode(text, normalize_embeddings=True)
14
- # 转换为列表返回(方便 API 传输)
15
- return embedding.tolist()
16
-
17
- def similarity_score(text1: str, text2: str) -> float:
18
- """计算两个文本的余弦相似度"""
19
- if not text1.strip() or not text2.strip():
20
- return 0.0
21
- emb1 = model.encode(text1, normalize_embeddings=True)
22
- emb2 = model.encode(text2, normalize_embeddings=True)
23
- # 余弦相似度 = 向量点积(已归一化)
24
- return float(np.dot(emb1, emb2))
25
-
26
- # 创建 Gradio 界面
27
- with gr.Blocks(title="开源文本嵌入 API") as demo:
28
- gr.Markdown("# 文本嵌入服务(基于 BAAI/bge-small-en-v1.5)")
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- with gr.Tab("生成嵌入向量"):
31
- input_text = gr.Textbox(label="输入文本", placeholder="请输入需要生成嵌入的文本...")
32
- embedding_output = gr.Textbox(label="嵌入向量(前10位)")
33
- generate_btn = gr.Button("生成嵌入")
34
- generate_btn.click(
35
- fn=lambda x: str(get_embedding(x)[:10]) + "...", # 只显示前10位
36
- inputs=input_text,
37
- outputs=embedding_output
38
- )
 
 
39
 
40
- with gr.Tab("计算语义相似度"):
41
- text1 = gr.Textbox(label="文本1", placeholder="输入第一个文本...")
42
- text2 = gr.Textbox(label="文本2", placeholder="输入第二个文本...")
43
- similarity_output = gr.Number(label="余弦相似度(0~1,越高越相似)")
44
- similarity_btn = gr.Button("计算相似度")
45
- similarity_btn.click(
46
- fn=similarity_score,
47
- inputs=[text1, text2],
48
- outputs=similarity_output
49
- )
 
50
 
51
- # 启用队列,支持并发请求
52
- demo.queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- if __name__ == "__main__":
55
- # 部署到 Hugging Face Spaces 时,不需要指定 server_name 和 server_port
56
- demo.launch()
 
 
1
+ from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer
 
3
  import numpy as np
4
+ import os
5
+ import time
6
 
7
app = Flask(__name__)

# Load the embedding model once at import time; every request handler
# shares this single instance.
model_name = "BAAI/bge-small-en-v1.5"
model = SentenceTransformer(model_name)

# Model aliases accepted by the API. Both names resolve to the same loaded
# model so OpenAI-style clients requesting "text-embedding-3-small" work
# unchanged.
SUPPORTED_MODELS = {
    "text-embedding-3-small": model,
    "bge-small-en-v1.5": model,
}
18
+
19
+ # 简单的API密钥验证(可选)
20
# Bearer token for API auth. NOTE(review): the fallback value is a placeholder —
# set the API_KEY env var in production, or anyone who reads the source can call
# the service.
API_KEY = os.getenv("API_KEY", "your-default-api-key")

def verify_api_key(headers):
    """Check the Authorization header for a valid bearer token.

    Args:
        headers: A mapping with ``.get`` (e.g. Flask's ``request.headers``).

    Returns:
        bool: True iff the header is exactly ``Bearer <API_KEY>``.
    """
    # Local import so the top-of-file import block stays untouched.
    import hmac

    auth_header = headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return False
    # Slice off the prefix instead of split("Bearer ")[1]: split would mangle
    # any token that itself contains the substring "Bearer ".
    token = auth_header[len("Bearer "):]
    # Constant-time comparison avoids leaking the key through timing.
    return hmac.compare_digest(token, API_KEY)
28
+
29
@app.route('/v1/embeddings', methods=['POST'])
def create_embedding():
    """Generate embeddings; request/response mirror OpenAI's /v1/embeddings.

    Expects JSON ``{"input": str | [str, ...], "model": optional str}`` and an
    ``Authorization: Bearer <API_KEY>`` header. Failures return an OpenAI-style
    error envelope with status 401 (bad key), 400 (missing/invalid body), or
    404 (unknown model).
    """

    def error(message, code, status):
        # Every failure shares the same OpenAI-compatible error envelope.
        return jsonify({
            "error": {
                "message": message,
                "type": "invalid_request_error",
                "param": None,
                "code": code,
            }
        }), status

    if not verify_api_key(request.headers):
        return error("Invalid API key", "invalid_api_key", 401)

    # silent=True yields None for a malformed/missing JSON body instead of
    # letting Flask raise an HTML error page — keeps all errors in JSON.
    data = request.get_json(silent=True)
    if not data or "input" not in data:
        return error("Missing input", "missing_input", 400)

    # Default matches OpenAI's small embedding model name.
    requested_model = data.get("model", "text-embedding-3-small")
    if requested_model not in SUPPORTED_MODELS:
        return error(f"Model {requested_model} not found", "model_not_found", 404)

    # Accept a single string or a list of strings, as the OpenAI API does.
    inputs = data["input"]
    if isinstance(inputs, str):
        inputs = [inputs]

    # Encode with the model the client asked for. All aliases currently map to
    # the same instance, but the lookup must not be discarded (the original
    # always used the global `model`, ignoring the table).
    embeddings = SUPPORTED_MODELS[requested_model].encode(
        inputs, normalize_embeddings=True
    )

    # Whitespace word count is a rough token estimate; compute it once.
    token_estimate = sum(len(text.split()) for text in inputs)

    return jsonify({
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": emb.tolist(), "index": i}
            for i, emb in enumerate(embeddings)
        ],
        "model": requested_model,
        "usage": {
            "prompt_tokens": token_estimate,
            "total_tokens": token_estimate,
        },
    })
95
+
96
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report service status, the loaded model, and the
    model aliases the embeddings endpoint accepts."""
    payload = {
        "status": "healthy",
        "model": model_name,
        "supported_models": list(SUPPORTED_MODELS.keys()),
    }
    return jsonify(payload)
104
 
105
if __name__ == '__main__':
    # Development server only — production runs behind a WSGI server such as
    # Gunicorn (see Procfile). PORT defaults to 7860 (Hugging Face Spaces).
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
108
+
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- gradio==4.28.3
2
  sentence-transformers==2.7.0
3
  torch==2.2.2
4
  numpy==1.26.4
 
 
1
+ flask==2.3.3
2
  sentence-transformers==2.7.0
3
  torch==2.2.2
4
  numpy==1.26.4
5
+ gunicorn==21.2.0 # WSGI server for production deployment