fiewolf1000 commited on
Commit
75911cc
·
verified ·
1 Parent(s): b532619

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -42,23 +42,31 @@ models = {}
42
 
43
  def get_model(model_name: str):
44
  logger.info(f"尝试获取模型: {model_name}")
 
 
45
  model_to_load = MODEL_MAPPING.get(model_name, model_name)
46
-
47
- # 提前检查模型是否已知支持列表中(可选)
48
- known_models = set(MODEL_MAPPING.keys()).union(set(MODEL_MAPPING.values()))
49
- if model_to_load not in known_models and not model_to_load.startswith(("BAAI/", "sentence-transformers/")):
50
- raise HTTPException(status_code=400, detail=f"不支持的模型: {model_name}")
51
-
 
 
 
52
  if model_name not in models:
53
  try:
54
  hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
55
- models[model_name] = SentenceTransformer(model_to_load, use_auth_token=hf_token)
 
 
 
56
  logger.info(f"模型 {model_name} 加载成功")
57
  except Exception as e:
58
- if "not a valid model identifier" in str(e):
59
- raise HTTPException(status_code=400, detail=f"不支持的模型: {model_name}")
60
- else:
61
- raise HTTPException(status_code=500, detail=f"加载模型失败: {str(e)}")
62
  return models[model_name]
63
 
64
 
 
42
 
43
  def get_model(model_name: str):
44
  logger.info(f"尝试获取模型: {model_name}")
45
+ # 1. 定义所有支持的模型(映射名 + 直接支持的模型名)
46
+ supported_models = set(MODEL_MAPPING.keys()) # 包含text-embedding-3-*和bge-*
47
  model_to_load = MODEL_MAPPING.get(model_name, model_name)
48
+
49
+ # 2. 提前拦截无效模型:若不在支持列表且非已知机构前缀,直接返回400
50
+ known_prefixes = ("BAAI/", "sentence-transformers/") # 允许合法机构的模型
51
+ if (model_name not in supported_models) and (not model_to_load.startswith(known_prefixes)):
52
+ error_msg = f"不支持的模型: {model_name}"
53
+ logger.error(error_msg)
54
+ raise HTTPException(status_code=400, detail=error_msg)
55
+
56
+ # 3. 加载支持的模型(含合法机构前缀的模型)
57
  if model_name not in models:
58
  try:
59
  hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
60
+ models[model_name] = SentenceTransformer(
61
+ model_to_load,
62
+ use_auth_token=hf_token
63
+ )
64
  logger.info(f"模型 {model_name} 加载成功")
65
  except Exception as e:
66
+ # 若合法模型加载失败(如网络问题),返回500;无效模型已提前拦截
67
+ error_msg = f"加载模型 {model_name} 失败: {str(e)}"
68
+ logger.error(error_msg)
69
+ raise HTTPException(status_code=500, detail=error_msg)
70
  return models[model_name]
71
 
72