fiewolf1000 committed on
Commit
2a769ef
·
verified ·
1 Parent(s): db84bf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -30
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException, Depends, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import os
@@ -6,7 +6,6 @@ import numpy as np
6
  from sentence_transformers import SentenceTransformer
7
  from typing import List, Optional
8
  import logging
9
- from fastapi import Header
10
 
11
  # 配置日志
12
  logging.basicConfig(
@@ -33,7 +32,7 @@ app.add_middleware(
33
  # 模型映射:OpenAI模型名 → 开源模型名
34
  MODEL_MAPPING = {
35
  "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
36
- "text-embedding-3-large": "BAAI/bge-large-en-v1.5" # 新增大模型映射
37
  }
38
 
39
  # 加载模型(懒加载,首次请求时加载)
@@ -42,12 +41,10 @@ models = {}
42
  def get_model(model_name: str):
43
  logger.info(f"尝试获取模型: {model_name}")
44
  if model_name not in models:
45
- # 检查是否支持该模型
46
  if model_name not in MODEL_MAPPING:
47
  error_msg = f"不支持的模型: {model_name}"
48
  logger.error(error_msg)
49
  raise HTTPException(status_code=400, detail=error_msg)
50
- # 加载模型
51
  logger.info(f"开始加载模型: {MODEL_MAPPING[model_name]}")
52
  try:
53
  models[model_name] = SentenceTransformer(MODEL_MAPPING[model_name])
@@ -59,8 +56,8 @@ def get_model(model_name: str):
59
  return models[model_name]
60
 
61
  # 验证API密钥
62
- def verify_api_key(authorization: Optional[str] = None):
63
- logger.info("验证API密钥")
64
  logger.info(f"Authorization头部内容: {authorization}")
65
  if not authorization or not authorization.startswith("Bearer "):
66
  logger.warning("未提供有效的API密钥格式")
@@ -72,13 +69,13 @@ def verify_api_key(authorization: Optional[str] = None):
72
  logger.info("API密钥验证通过")
73
  return True
74
 
75
- # 请求体模型(对齐OpenAI格式)
76
  class EmbeddingRequest(BaseModel):
77
  input: str or List[str]
78
  model: str
79
- encoding_format: Optional[str] = "float" # 仅支持float,忽略base64
80
 
81
- # 响应体模型(对齐OpenAI格式)
82
  class EmbeddingData(BaseModel):
83
  object: str = "embedding"
84
  embedding: List[float]
@@ -90,23 +87,26 @@ class EmbeddingResponse(BaseModel):
90
  model: str
91
  usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
92
 
93
- @app.post("/embeddings", response_model=EmbeddingResponse)
94
  async def create_embedding(
95
- request: Request, # 接收Request对象
96
- req: EmbeddingRequest,
97
- _: bool = Depends(verify_api_key)
98
  ):
99
- # 打印完整请求信息
100
  logger.info("\n===== 接收到的完整请求信息 =====")
101
  logger.info(f"请求方法: {request.method}")
102
  logger.info(f"请求URL: {request.url}")
103
  logger.info("请求头部:")
104
  for name, value in request.headers.items():
105
  logger.info(f" {name}: {value}")
106
- logger.info(f"请求体: {await request.body()}") # 打印原始请求体
107
  logger.info("===============================\n")
108
-
109
- # 原有逻辑保持不变
 
 
 
 
110
  logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
111
  try:
112
  model = get_model(req.model)
@@ -135,18 +135,6 @@ async def create_embedding(
135
  logger.error(error_msg)
136
  raise HTTPException(status_code=500, detail=error_msg)
137
 
138
- # 健康检查接口也打印完整请求
139
- @app.post("/v1/embeddings")
140
- async def v1_check(request: Request):
141
- logger.info("\n===== v1请求信息 =====")
142
- logger.info(f"请求方法: {request.method}")
143
- logger.info(f"请求URL: {request.url}")
144
- logger.info("请求头部:")
145
- for name, value in request.headers.items():
146
- logger.info(f" {name}: {value}")
147
- logger.info("===============================\n")
148
- return {"status": "healthy", "models": list(MODEL_MAPPING.keys())}
149
-
150
  @app.get("/health")
151
  async def health_check(request: Request):
152
  logger.info("\n===== 健康检查请求信息 =====")
 
1
+ from fastapi import FastAPI, HTTPException, Depends, Request, Header
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import os
 
6
  from sentence_transformers import SentenceTransformer
7
  from typing import List, Optional
8
  import logging
 
9
 
10
  # 配置日志
11
  logging.basicConfig(
 
32
  # 模型映射:OpenAI模型名 → 开源模型名
33
  MODEL_MAPPING = {
34
  "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
35
+ "text-embedding-3-large": "BAAI/bge-large-en-v1.5"
36
  }
37
 
38
  # 加载模型(懒加载,首次请求时加载)
 
41
  def get_model(model_name: str):
42
  logger.info(f"尝试获取模型: {model_name}")
43
  if model_name not in models:
 
44
  if model_name not in MODEL_MAPPING:
45
  error_msg = f"不支持的模型: {model_name}"
46
  logger.error(error_msg)
47
  raise HTTPException(status_code=400, detail=error_msg)
 
48
  logger.info(f"开始加载模型: {MODEL_MAPPING[model_name]}")
49
  try:
50
  models[model_name] = SentenceTransformer(MODEL_MAPPING[model_name])
 
56
  return models[model_name]
57
 
58
  # 验证API密钥
59
+ def verify_api_key(authorization: Optional[str] = Header(None)):
60
+ logger.info("执行API密钥验证")
61
  logger.info(f"Authorization头部内容: {authorization}")
62
  if not authorization or not authorization.startswith("Bearer "):
63
  logger.warning("未提供有效的API密钥格式")
 
69
  logger.info("API密钥验证通过")
70
  return True
71
 
72
+ # 请求体模型
73
  class EmbeddingRequest(BaseModel):
74
  input: str or List[str]
75
  model: str
76
+ encoding_format: Optional[str] = "float"
77
 
78
+ # 响应体模型
79
  class EmbeddingData(BaseModel):
80
  object: str = "embedding"
81
  embedding: List[float]
 
87
  model: str
88
  usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
89
 
90
+ @app.post("/v1/embeddings", response_model=EmbeddingResponse)
91
  async def create_embedding(
92
+ request: Request,
93
+ req: EmbeddingRequest
 
94
  ):
95
+ # 打印完整请求信息(在验证之前)
96
  logger.info("\n===== 接收到的完整请求信息 =====")
97
  logger.info(f"请求方法: {request.method}")
98
  logger.info(f"请求URL: {request.url}")
99
  logger.info("请求头部:")
100
  for name, value in request.headers.items():
101
  logger.info(f" {name}: {value}")
102
+ logger.info(f"请求体: {await request.body()}")
103
  logger.info("===============================\n")
104
+
105
+ # 手动执行验证(在打印日志之后)
106
+ authorization = request.headers.get("Authorization")
107
+ verify_api_key(authorization)
108
+
109
+ # 原有嵌入处理逻辑
110
  logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
111
  try:
112
  model = get_model(req.model)
 
135
  logger.error(error_msg)
136
  raise HTTPException(status_code=500, detail=error_msg)
137
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  @app.get("/health")
139
  async def health_check(request: Request):
140
  logger.info("\n===== 健康检查请求信息 =====")