Fix the bug of model matching format error
Browse files
main.py
CHANGED
|
@@ -21,7 +21,8 @@ from urllib.parse import urlparse
|
|
| 21 |
@asynccontextmanager
|
| 22 |
async def lifespan(app: FastAPI):
|
| 23 |
# 启动时的代码
|
| 24 |
-
|
|
|
|
| 25 |
yield
|
| 26 |
# 关闭时的代码
|
| 27 |
await app.state.client.aclose()
|
|
@@ -35,7 +36,20 @@ security = HTTPBearer()
|
|
| 35 |
def load_config():
|
| 36 |
try:
|
| 37 |
with open('api.yaml', 'r') as f:
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
except FileNotFoundError:
|
| 40 |
print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
|
| 41 |
return []
|
|
@@ -43,19 +57,7 @@ def load_config():
|
|
| 43 |
print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
|
| 44 |
return []
|
| 45 |
|
| 46 |
-
config = load_config()
|
| 47 |
-
for index, provider in enumerate(config['providers']):
|
| 48 |
-
model_dict = {}
|
| 49 |
-
for model in provider['model']:
|
| 50 |
-
if type(model) == str:
|
| 51 |
-
model_dict[model] = model
|
| 52 |
-
if type(model) == dict:
|
| 53 |
-
model_dict.update({value: key for key, value in model.items()})
|
| 54 |
-
provider['model'] = model_dict
|
| 55 |
-
config['providers'][index] = provider
|
| 56 |
-
api_keys_db = config['api_keys']
|
| 57 |
-
api_list = [item["api"] for item in api_keys_db]
|
| 58 |
-
print(json.dumps(config, indent=4, ensure_ascii=False))
|
| 59 |
|
| 60 |
async def process_request(request: RequestModel, provider: Dict):
|
| 61 |
print("provider: ", provider['provider'])
|
|
@@ -102,7 +104,10 @@ class ModelRequestHandler:
|
|
| 102 |
if "/" in model:
|
| 103 |
provider_name = model.split("/")[0]
|
| 104 |
model = model.split("/")[1]
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
provider_rules.append(provider_name)
|
| 107 |
provider_list = []
|
| 108 |
for provider in config['providers']:
|
|
@@ -250,6 +255,11 @@ def generate_api_key():
|
|
| 250 |
api_key = "sk-" + secrets.token_urlsafe(32)
|
| 251 |
return {"api_key": api_key}
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
if __name__ == '__main__':
|
| 254 |
import uvicorn
|
| 255 |
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|
|
|
|
| 21 |
@asynccontextmanager
|
| 22 |
async def lifespan(app: FastAPI):
|
| 23 |
# 启动时的代码
|
| 24 |
+
timeout = httpx.Timeout(connect=10.0, read=30.0, write=30.0, pool=30.0)
|
| 25 |
+
app.state.client = httpx.AsyncClient(timeout=timeout)
|
| 26 |
yield
|
| 27 |
# 关闭时的代码
|
| 28 |
await app.state.client.aclose()
|
|
|
|
| 36 |
def load_config():
|
| 37 |
try:
|
| 38 |
with open('api.yaml', 'r') as f:
|
| 39 |
+
conf = yaml.safe_load(f)
|
| 40 |
+
for index, provider in enumerate(conf['providers']):
|
| 41 |
+
model_dict = {}
|
| 42 |
+
for model in provider['model']:
|
| 43 |
+
if type(model) == str:
|
| 44 |
+
model_dict[model] = model
|
| 45 |
+
if type(model) == dict:
|
| 46 |
+
model_dict.update({value: key for key, value in model.items()})
|
| 47 |
+
provider['model'] = model_dict
|
| 48 |
+
conf['providers'][index] = provider
|
| 49 |
+
api_keys_db = conf['api_keys']
|
| 50 |
+
api_list = [item["api"] for item in api_keys_db]
|
| 51 |
+
print(json.dumps(conf, indent=4, ensure_ascii=False))
|
| 52 |
+
return conf, api_keys_db, api_list
|
| 53 |
except FileNotFoundError:
|
| 54 |
print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
|
| 55 |
return []
|
|
|
|
| 57 |
print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
|
| 58 |
return []
|
| 59 |
|
| 60 |
+
config, api_keys_db, api_list = load_config()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
async def process_request(request: RequestModel, provider: Dict):
|
| 63 |
print("provider: ", provider['provider'])
|
|
|
|
| 104 |
if "/" in model:
|
| 105 |
provider_name = model.split("/")[0]
|
| 106 |
model = model.split("/")[1]
|
| 107 |
+
for provider in config['providers']:
|
| 108 |
+
if provider['provider'] == provider_name:
|
| 109 |
+
models_list = provider['model'].keys()
|
| 110 |
+
if (model and model_name == model) or (model == "*" and model_name in models_list):
|
| 111 |
provider_rules.append(provider_name)
|
| 112 |
provider_list = []
|
| 113 |
for provider in config['providers']:
|
|
|
|
| 255 |
api_key = "sk-" + secrets.token_urlsafe(32)
|
| 256 |
return {"api_key": api_key}
|
| 257 |
|
| 258 |
+
async def on_fetch(request, env):
|
| 259 |
+
import asgi
|
| 260 |
+
|
| 261 |
+
return await asgi.fetch(app, request, env)
|
| 262 |
+
|
| 263 |
if __name__ == '__main__':
|
| 264 |
import uvicorn
|
| 265 |
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)
|