fiewolf1000 commited on
Commit
4eb2455
·
verified ·
1 Parent(s): 2a769ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -26
app.py CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
4
  import os
5
  import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
- from typing import List, Optional
8
  import logging
9
 
10
  # 配置日志
@@ -12,8 +12,8 @@ logging.basicConfig(
12
  level=logging.INFO,
13
  format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
14
  handlers=[
15
- logging.FileHandler("embedding_service.log"), # 日志写入文件
16
- logging.StreamHandler() # 同时输出到控制台
17
  ]
18
  )
19
  logger = logging.getLogger("embedding_service")
@@ -35,19 +35,15 @@ MODEL_MAPPING = {
35
  "text-embedding-3-large": "BAAI/bge-large-en-v1.5"
36
  }
37
 
38
- # 加载模型(懒加载,首次请求时加载)
39
  models = {}
40
 
41
  def get_model(model_name: str):
42
  logger.info(f"尝试获取模型: {model_name}")
 
43
  if model_name not in models:
44
- if model_name not in MODEL_MAPPING:
45
- error_msg = f"不支持的模型: {model_name}"
46
- logger.error(error_msg)
47
- raise HTTPException(status_code=400, detail=error_msg)
48
- logger.info(f"开始加载模型: {MODEL_MAPPING[model_name]}")
49
  try:
50
- models[model_name] = SentenceTransformer(MODEL_MAPPING[model_name])
51
  logger.info(f"模型 {model_name} 加载成功")
52
  except Exception as e:
53
  error_msg = f"加载模型 {model_name} 失败: {str(e)}"
@@ -57,21 +53,18 @@ def get_model(model_name: str):
57
 
58
  # 验证API密钥
59
  def verify_api_key(authorization: Optional[str] = Header(None)):
60
- logger.info("执行API密钥验证")
61
  logger.info(f"Authorization头部内容: {authorization}")
62
  if not authorization or not authorization.startswith("Bearer "):
63
- logger.warning("未提供有效的API密钥格式")
64
  raise HTTPException(status_code=401, detail="未提供有效的API密钥")
65
  api_key = authorization[len("Bearer "):]
66
  if api_key != os.getenv("API_KEY"):
67
- logger.warning("无效的API密钥")
68
  raise HTTPException(status_code=401, detail="无效的API密钥")
69
  logger.info("API密钥验证通过")
70
  return True
71
 
72
  # 请求体模型
73
  class EmbeddingRequest(BaseModel):
74
- input: str or List[str]
75
  model: str
76
  encoding_format: Optional[str] = "float"
77
 
@@ -90,9 +83,10 @@ class EmbeddingResponse(BaseModel):
90
  @app.post("/v1/embeddings", response_model=EmbeddingResponse)
91
  async def create_embedding(
92
  request: Request,
93
- req: EmbeddingRequest
 
94
  ):
95
- # 先打印完整请求信息(在验证之前)
96
  logger.info("\n===== 接收到的完整请求信息 =====")
97
  logger.info(f"请求方法: {request.method}")
98
  logger.info(f"请求URL: {request.url}")
@@ -102,11 +96,7 @@ async def create_embedding(
102
  logger.info(f"请求体: {await request.body()}")
103
  logger.info("===============================\n")
104
 
105
- # 手动执行验证(在打印日志之后)
106
- authorization = request.headers.get("Authorization")
107
- verify_api_key(authorization)
108
-
109
- # 原有嵌入处理逻辑
110
  logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
111
  try:
112
  model = get_model(req.model)
@@ -131,9 +121,7 @@ async def create_embedding(
131
  usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
132
  )
133
  except Exception as e:
134
- error_msg = f"处理嵌入请求时发生错误: {str(e)}"
135
- logger.error(error_msg)
136
- raise HTTPException(status_code=500, detail=error_msg)
137
 
138
  @app.get("/health")
139
  async def health_check(request: Request):
@@ -144,9 +132,9 @@ async def health_check(request: Request):
144
  for name, value in request.headers.items():
145
  logger.info(f" {name}: {value}")
146
  logger.info("===============================\n")
147
- return {"status": "healthy", "models": list(MODEL_MAPPING.keys())}
148
 
149
  if __name__ == "__main__":
150
  import uvicorn
151
  logger.info("启动服务")
152
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
4
  import os
5
  import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
+ from typing import List, Optional, Union # 导入Union
8
  import logging
9
 
10
  # 配置日志
 
12
  level=logging.INFO,
13
  format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
14
  handlers=[
15
+ logging.FileHandler("embedding_service.log"),
16
+ logging.StreamHandler()
17
  ]
18
  )
19
  logger = logging.getLogger("embedding_service")
 
35
  "text-embedding-3-large": "BAAI/bge-large-en-v1.5"
36
  }
37
 
38
+ # 加载模型(懒加载)
39
  models = {}
40
 
41
  def get_model(model_name: str):
42
  logger.info(f"尝试获取模型: {model_name}")
43
+ model_to_load = MODEL_MAPPING.get(model_name, model_name) # 兼容直接用开源模型名
44
  if model_name not in models:
 
 
 
 
 
45
  try:
46
+ models[model_name] = SentenceTransformer(model_to_load)
47
  logger.info(f"模型 {model_name} 加载成功")
48
  except Exception as e:
49
  error_msg = f"加载模型 {model_name} 失败: {str(e)}"
 
53
 
54
  # 验证API密钥
55
  def verify_api_key(authorization: Optional[str] = Header(None)):
 
56
  logger.info(f"Authorization头部内容: {authorization}")
57
  if not authorization or not authorization.startswith("Bearer "):
 
58
  raise HTTPException(status_code=401, detail="未提供有效的API密钥")
59
  api_key = authorization[len("Bearer "):]
60
  if api_key != os.getenv("API_KEY"):
 
61
  raise HTTPException(status_code=401, detail="无效的API密钥")
62
  logger.info("API密钥验证通过")
63
  return True
64
 
65
  # 请求体模型
66
  class EmbeddingRequest(BaseModel):
67
+ input: Union[str, List[str]] # 支持str或List[str]
68
  model: str
69
  encoding_format: Optional[str] = "float"
70
 
 
83
  @app.post("/v1/embeddings", response_model=EmbeddingResponse)
84
  async def create_embedding(
85
  request: Request,
86
+ req: EmbeddingRequest,
87
+ _: bool = Depends(verify_api_key)
88
  ):
89
+ # 打印请求信息
90
  logger.info("\n===== 接收到的完整请求信息 =====")
91
  logger.info(f"请求方法: {request.method}")
92
  logger.info(f"请求URL: {request.url}")
 
96
  logger.info(f"请求体: {await request.body()}")
97
  logger.info("===============================\n")
98
 
99
+ # 嵌入生成逻辑
 
 
 
 
100
  logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
101
  try:
102
  model = get_model(req.model)
 
121
  usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
122
  )
123
  except Exception as e:
124
+ raise HTTPException(status_code=500, detail=f"处理嵌入请求时发生错误: {str(e)}")
 
 
125
 
126
  @app.get("/health")
127
  async def health_check(request: Request):
 
132
  for name, value in request.headers.items():
133
  logger.info(f" {name}: {value}")
134
  logger.info("===============================\n")
135
+ return {"status": "healthy", "models": list(MODEL_MAPPING.keys()) + list(models.keys())}
136
 
137
  if __name__ == "__main__":
138
  import uvicorn
139
  logger.info("启动服务")
140
+ uvicorn.run(app, host="0.0.0.0", port=7860)