# File size: 7,957 Bytes
# 6cfe55f
from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model
from app.db.provider_dao import get_enabled_providers
from app.enmus.exception import ProviderErrorEnum
from app.exceptions.provider import ProviderError
from app.gpt.gpt_factory import GPTFactory
from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider
from app.models.model_config import ModelConfig
from app.services.provider import ProviderService
from app.utils.logger import get_logger
# Module-level logger, obtained from the project's get_logger helper with this module's name.
logger=get_logger(__name__)
class ModelService:
@staticmethod
def _build_model_config(provider: dict) -> ModelConfig:
return ModelConfig(
api_key=provider["api_key"],
base_url=provider["base_url"],
provider=provider["name"],
model_name='',
name=provider["name"],
)
@staticmethod
def get_model_list(provider_id: int, verbose: bool = False):
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
return []
try:
config = ModelService._build_model_config(provider)
gpt = GPTFactory().from_config(config)
models = gpt.list_models()
if verbose:
print(f"[{provider['name']}] 模型列表: {models}")
return models
except Exception as e:
print(f"[{provider['name']}] 获取模型失败: {e}")
return []
@staticmethod
def get_all_models(verbose: bool = False):
try:
raw_models = get_all_models()
if verbose:
print(f"所有模型列表: {raw_models}")
return ModelService._format_models(raw_models)
except Exception as e:
print(f"获取所有模型失败: {e}")
return []
@staticmethod
def get_all_models_safe(verbose: bool = False):
try:
raw_models = get_all_models()
if verbose:
print(f"所有模型列表: {raw_models}")
return ModelService._format_models(raw_models)
except Exception as e:
print(f"获取所有模型失败: {e}")
return []
@staticmethod
def _format_models(raw_models: list) -> list:
"""
格式化模型列表
"""
formatted = []
for model in raw_models:
formatted.append({
"id": model.get("id"),
"provider_id": model.get("provider_id"),
"model_name": model.get("model_name"),
"created_at": model.get("created_at", None), # 如果有created_at字段
})
return formatted
@staticmethod
def _extract_remote_models(raw_models) -> list:
if raw_models is None:
return []
if isinstance(raw_models, dict):
raw_models = raw_models.get("data", raw_models.get("models", raw_models))
elif hasattr(raw_models, "data"):
raw_models = raw_models.data
if isinstance(raw_models, list):
return raw_models
return []
@staticmethod
def _serialize_remote_model(model) -> dict:
if isinstance(model, dict):
return model
if hasattr(model, "model_dump"):
return model.model_dump()
if hasattr(model, "dict"):
return model.dict()
model_id = getattr(model, "id", None)
if model_id:
return {
"id": model_id,
"object": getattr(model, "object", "model"),
"created": getattr(model, "created", None),
"owned_by": getattr(model, "owned_by", None),
}
return {}
@staticmethod
def get_enabled_models_by_provider( provider_id: str|int,):
from app.db.model_dao import get_models_by_provider
all_models = get_models_by_provider(provider_id)
enabled_models = all_models
return enabled_models
@staticmethod
def get_all_models_by_id(provider_id: str, verbose: bool = False):
try:
provider = ProviderService.get_provider_by_id(provider_id)
models = ModelService.get_model_list(provider["id"], verbose=verbose)
remote_models = ModelService._extract_remote_models(models)
serializable_models = [
item
for item in (ModelService._serialize_remote_model(model) for model in remote_models)
if item.get("id")
]
model_list = {
"models": serializable_models
}
logger.info(f"[{provider['name']}] 获取模型成功")
return model_list
except Exception as e:
# print(f"[{provider_id}] 获取模型失败: {e}")
logger.error(f"[{provider_id}] 获取模型失败: {e}")
return []
@staticmethod
def connect_test(id: str, model: str | None = None) -> bool:
"""连通性测试:发一条最小化 chat completion。
model 优先级:
1. 调用方显式传入(前端可在「模型选择」UI 里挑一个再测)
2. DB 中该 provider 已保存的第一个模型
3. 都没有 → 抛错让用户先加一个模型
"""
provider = ProviderService.get_provider_by_id(id)
if not provider:
raise ProviderError(
code=ProviderErrorEnum.NOT_FOUND.code,
message=ProviderErrorEnum.NOT_FOUND.message,
)
if not provider.get('api_key'):
raise ProviderError(
code=ProviderErrorEnum.NOT_FOUND.code,
message=ProviderErrorEnum.NOT_FOUND.message,
)
if not model:
saved_models = ModelService.get_enabled_models_by_provider(provider["id"])
if not saved_models:
raise ProviderError(
code=ProviderErrorEnum.WRONG_PARAMETER.code,
message="请先为该供应商添加至少一个模型再测试连通性",
)
model = saved_models[0]["model_name"]
ok = OpenAICompatibleProvider.test_connection(
api_key=provider.get('api_key'),
base_url=provider.get('base_url'),
model=model,
)
if ok:
return True
raise ProviderError(
code=ProviderErrorEnum.WRONG_PARAMETER.code,
message=ProviderErrorEnum.WRONG_PARAMETER.message,
)
@staticmethod
def delete_model_by_id( model_id: int) -> bool:
try:
delete_model(model_id)
return True
except Exception as e:
print(f"[{model_id}] <UNK>: {e}")
return False
@staticmethod
def add_new_model(provider_id: int, model_name: str) -> bool:
try:
# 先查供应商是否存在
provider = ProviderService.get_provider_by_id(provider_id)
if not provider:
print(f"供应商ID {provider_id} 不存在,无法添加模型")
return False
# 查询是否已存在同名模型
existing = get_model_by_provider_and_name(provider_id, model_name)
if existing:
print(f"模型 {model_name} 已存在于供应商ID {provider_id} 下,跳过插入")
return False
# 插入模型
insert_model(provider_id=provider_id, model_name=model_name)
print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}")
return True
except Exception as e:
print(f"添加模型失败: {e}")
return False
if __name__ == '__main__':
    # Manual smoke test for a single provider: fetch the remote model listing
    # of provider 1 with verbose output, then show what was returned.
    provider_models = ModelService.get_model_list(1, verbose=True)
    print(provider_models)
    # To check the stored models of all providers instead, run:
    # print(ModelService.get_all_models(verbose=True))
|