fiewolf1000 committed on
Commit
1e5bdfa
·
verified ·
1 Parent(s): 9493f68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -5
app.py CHANGED
@@ -5,6 +5,18 @@ import os
5
  import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
  from typing import List, Optional
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  app = FastAPI()
10
 
@@ -27,21 +39,35 @@ MODEL_MAPPING = {
27
  models = {}
28
 
29
  def get_model(model_name: str):
 
30
  if model_name not in models:
31
  # 检查是否支持该模型
32
  if model_name not in MODEL_MAPPING:
33
- raise HTTPException(status_code=400, detail=f"不支持的模型: {model_name}")
 
 
34
  # 加载模型
35
- models[model_name] = SentenceTransformer(MODEL_MAPPING[model_name])
 
 
 
 
 
 
 
36
  return models[model_name]
37
 
38
  # 验证API密钥
39
  def verify_api_key(authorization: Optional[str] = None):
 
40
  if not authorization or not authorization.startswith("Bearer "):
 
41
  raise HTTPException(status_code=401, detail="未提供有效的API密钥")
42
  api_key = authorization[len("Bearer "):]
43
  if api_key != os.getenv("API_KEY"):
 
44
  raise HTTPException(status_code=401, detail="无效的API密钥")
 
45
  return True
46
 
47
  # 请求体模型(对齐OpenAI格式)
@@ -67,15 +93,19 @@ async def create_embedding(
67
  request: EmbeddingRequest,
68
  _: bool = Depends(verify_api_key)
69
  ):
 
70
  try:
71
  # 获取模型
72
  model = get_model(request.model)
73
 
74
  # 处理输入(支持单文本或文本列表)
75
  inputs = [request.input] if isinstance(request.input, str) else request.input
 
76
 
77
  # 计算嵌入
 
78
  embeddings = model.encode(inputs, normalize_embeddings=True)
 
79
 
80
  # 构建响应
81
  data = [
@@ -85,6 +115,7 @@ async def create_embedding(
85
 
86
  # 估算token数(简单近似:每个单词约1 token)
87
  prompt_tokens = sum(len(text.split()) for text in inputs)
 
88
 
89
  return EmbeddingResponse(
90
  data=data,
@@ -92,14 +123,17 @@ async def create_embedding(
92
  usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
93
  )
94
  except Exception as e:
95
- raise HTTPException(status_code=500, detail=str(e))
 
 
96
 
97
  # 健康检查接口
98
  @app.get("/health")
99
  async def health_check():
 
100
  return {"status": "healthy", "models": list(MODEL_MAPPING.keys())}
101
 
102
  if __name__ == "__main__":
103
  import uvicorn
104
- uvicorn.run(app, host="0.0.0.0", port=7860)
105
-
 
5
  import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
  from typing import List, Optional
8
+ import logging
9
+
10
+ # 配置日志
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
14
+ handlers=[
15
+ logging.FileHandler("embedding_service.log"), # 日志写入文件
16
+ logging.StreamHandler() # 同时输出到控制台
17
+ ]
18
+ )
19
+ logger = logging.getLogger("embedding_service")
20
 
21
  app = FastAPI()
22
 
 
39
  models = {}
40
 
41
def get_model(model_name: str):
    """Return the SentenceTransformer for *model_name*, loading it on first use.

    Loaded models are cached in the module-level ``models`` dict so each
    checkpoint is only downloaded/initialized once per process.

    Raises:
        HTTPException: 400 if *model_name* is not a key of ``MODEL_MAPPING``;
            500 if the underlying SentenceTransformer fails to load.
    """
    # Lazy %-style args: the message is only formatted if this level is enabled.
    logger.info("尝试获取模型: %s", model_name)
    if model_name not in models:
        # Reject names we cannot map to a concrete checkpoint.
        if model_name not in MODEL_MAPPING:
            error_msg = f"不支持的模型: {model_name}"
            logger.error(error_msg)
            raise HTTPException(status_code=400, detail=error_msg)
        # First request for this model: load it and cache for later calls.
        logger.info("开始加载模型: %s", MODEL_MAPPING[model_name])
        try:
            models[model_name] = SentenceTransformer(MODEL_MAPPING[model_name])
            logger.info("模型 %s 加载成功", model_name)
        except Exception as e:
            error_msg = f"加载模型 {model_name} 失败: {str(e)}"
            logger.error(error_msg)
            # Chain the original exception so the traceback keeps the root cause.
            raise HTTPException(status_code=500, detail=error_msg) from e
    return models[model_name]
59
 
60
  # 验证API密钥
61
def verify_api_key(authorization: Optional[str] = None):
    """Validate a ``Bearer <token>`` Authorization header against the
    ``API_KEY`` environment variable.

    Uses a constant-time comparison so the check does not leak key bytes
    through response timing.

    Raises:
        HTTPException: 401 when the header is missing/malformed or the key
            does not match.

    Returns:
        True when the key is valid.
    """
    import hmac  # local import keeps this security fix self-contained

    logger.info("验证API密钥")
    if not authorization or not authorization.startswith("Bearer "):
        logger.warning("未提供有效的API密钥格式")
        raise HTTPException(status_code=401, detail="未提供有效的API密钥")
    api_key = authorization[len("Bearer "):]
    # If API_KEY is unset, treat every presented key as invalid (original behavior).
    expected = os.getenv("API_KEY") or ""
    # hmac.compare_digest is constant-time, unlike `!=`, which short-circuits
    # on the first differing byte and enables timing attacks.
    if not expected or not hmac.compare_digest(api_key.encode(), expected.encode()):
        logger.warning("无效的API密钥")
        raise HTTPException(status_code=401, detail="无效的API密钥")
    logger.info("API密钥验证通过")
    return True
72
 
73
  # 请求体模型(对齐OpenAI格式)
 
93
  request: EmbeddingRequest,
94
  _: bool = Depends(verify_api_key)
95
  ):
96
+ logger.info(f"收到嵌入请求,模型: {request.model}, 输入类型: {type(request.input)}")
97
  try:
98
  # 获取模型
99
  model = get_model(request.model)
100
 
101
  # 处理输入(支持单文本或文本列表)
102
  inputs = [request.input] if isinstance(request.input, str) else request.input
103
+ logger.info(f"处理输入,文本数量: {len(inputs)}")
104
 
105
  # 计算嵌入
106
+ logger.info("开始计算嵌入")
107
  embeddings = model.encode(inputs, normalize_embeddings=True)
108
+ logger.info(f"嵌入计算完成,嵌入形状: {embeddings.shape}")
109
 
110
  # 构建响应
111
  data = [
 
115
 
116
  # 估算token数(简单近似:每个单词约1 token)
117
  prompt_tokens = sum(len(text.split()) for text in inputs)
118
+ logger.info(f"估算token数: {prompt_tokens}")
119
 
120
  return EmbeddingResponse(
121
  data=data,
 
123
  usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
124
  )
125
  except Exception as e:
126
+ error_msg = f"处理嵌入请求时发生错误: {str(e)}"
127
+ logger.error(error_msg)
128
+ raise HTTPException(status_code=500, detail=error_msg)
129
 
130
  # 健康检查接口
131
@app.get("/health")
async def health_check():
    """Liveness endpoint: reports service status and the supported model names."""
    logger.info("健康检查请求")
    supported = list(MODEL_MAPPING)
    return {"status": "healthy", "models": supported}
135
 
136
  if __name__ == "__main__":
137
  import uvicorn
138
+ logger.info("启动服务")
139
+ uvicorn.run(app, host="0.0.0.0", port=7860)