🐛 Bug: Fix the bug where the model is not persisted to the file after being automatically retrieved.
Browse files
main.py
CHANGED
|
@@ -18,7 +18,7 @@ from fastapi.exceptions import RequestValidationError
|
|
| 18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
| 19 |
from request import get_payload
|
| 20 |
from response import fetch_response, fetch_response_stream
|
| 21 |
-
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict
|
| 22 |
|
| 23 |
from collections import defaultdict
|
| 24 |
from typing import List, Dict, Union
|
|
@@ -1120,8 +1120,6 @@ async def frontend_rate_limit_dependency(request: Request, x_api_key: str = Depe
|
|
| 1120 |
|
| 1121 |
xue_initialize(tailwind=True)
|
| 1122 |
|
| 1123 |
-
API_YAML_PATH = "./api.yaml"
|
| 1124 |
-
|
| 1125 |
data_table_columns = [
|
| 1126 |
# {"label": "Status", "value": "status", "sortable": True},
|
| 1127 |
{"label": "Provider", "value": "provider", "sortable": True},
|
|
@@ -1500,10 +1498,6 @@ def update_row_data(row_id, updated_data):
|
|
| 1500 |
index = int(row_id)
|
| 1501 |
app.state.config["providers"][index] = updated_data
|
| 1502 |
|
| 1503 |
-
def save_api_yaml():
|
| 1504 |
-
with open(API_YAML_PATH, "w", encoding="utf-8") as f:
|
| 1505 |
-
yaml.dump(app.state.config, f)
|
| 1506 |
-
|
| 1507 |
@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
|
| 1508 |
async def submit_form(
|
| 1509 |
row_id: str,
|
|
@@ -1551,7 +1545,8 @@ async def submit_form(
|
|
| 1551 |
update_row_data(row_id, updated_data)
|
| 1552 |
|
| 1553 |
# 保存更新后的配置
|
| 1554 |
-
|
|
|
|
| 1555 |
|
| 1556 |
return await root()
|
| 1557 |
|
|
@@ -1564,7 +1559,8 @@ async def duplicate_row(row_id: str):
|
|
| 1564 |
app.state.config["providers"].insert(index + 1, new_data)
|
| 1565 |
|
| 1566 |
# 保存更新后的配置
|
| 1567 |
-
|
|
|
|
| 1568 |
|
| 1569 |
return await root()
|
| 1570 |
|
|
@@ -1574,7 +1570,8 @@ async def delete_row(row_id: str):
|
|
| 1574 |
del app.state.config["providers"][index]
|
| 1575 |
|
| 1576 |
# 保存更新后的配置
|
| 1577 |
-
|
|
|
|
| 1578 |
|
| 1579 |
return await root()
|
| 1580 |
|
|
|
|
| 18 |
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
|
| 19 |
from request import get_payload
|
| 20 |
from response import fetch_response, fetch_response_stream
|
| 21 |
+
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
|
| 22 |
|
| 23 |
from collections import defaultdict
|
| 24 |
from typing import List, Dict, Union
|
|
|
|
| 1120 |
|
| 1121 |
xue_initialize(tailwind=True)
|
| 1122 |
|
|
|
|
|
|
|
| 1123 |
data_table_columns = [
|
| 1124 |
# {"label": "Status", "value": "status", "sortable": True},
|
| 1125 |
{"label": "Provider", "value": "provider", "sortable": True},
|
|
|
|
| 1498 |
index = int(row_id)
|
| 1499 |
app.state.config["providers"][index] = updated_data
|
| 1500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1501 |
@frontend_router.post("/submit/{row_id}", response_class=HTMLResponse, dependencies=[Depends(frontend_rate_limit_dependency)])
|
| 1502 |
async def submit_form(
|
| 1503 |
row_id: str,
|
|
|
|
| 1545 |
update_row_data(row_id, updated_data)
|
| 1546 |
|
| 1547 |
# 保存更新后的配置
|
| 1548 |
+
if not DISABLE_DATABASE:
|
| 1549 |
+
save_api_yaml(app.state.config)
|
| 1550 |
|
| 1551 |
return await root()
|
| 1552 |
|
|
|
|
| 1559 |
app.state.config["providers"].insert(index + 1, new_data)
|
| 1560 |
|
| 1561 |
# 保存更新后的配置
|
| 1562 |
+
if not DISABLE_DATABASE:
|
| 1563 |
+
save_api_yaml(app.state.config)
|
| 1564 |
|
| 1565 |
return await root()
|
| 1566 |
|
|
|
|
| 1570 |
del app.state.config["providers"][index]
|
| 1571 |
|
| 1572 |
# 保存更新后的配置
|
| 1573 |
+
if not DISABLE_DATABASE:
|
| 1574 |
+
save_api_yaml(app.state.config)
|
| 1575 |
|
| 1576 |
return await root()
|
| 1577 |
|
utils.py
CHANGED
|
@@ -63,7 +63,18 @@ def update_initial_model(api_url, api):
|
|
| 63 |
traceback.print_exc()
|
| 64 |
return []
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
for index, provider in enumerate(config_data['providers']):
|
| 68 |
if provider.get('project_id'):
|
| 69 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
|
@@ -78,7 +89,11 @@ def update_config(config_data):
|
|
| 78 |
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
|
| 79 |
|
| 80 |
if not provider.get("model"):
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
if provider.get("tools") == None:
|
| 84 |
provider["tools"] = True
|
|
@@ -128,16 +143,12 @@ async def load_config(app=None):
|
|
| 128 |
follow_redirects=True, # 自动跟随重定向
|
| 129 |
)
|
| 130 |
|
| 131 |
-
from ruamel.yaml import YAML, YAMLError
|
| 132 |
-
yaml = YAML()
|
| 133 |
-
yaml.preserve_quotes = True
|
| 134 |
-
yaml.indent(mapping=2, sequence=4, offset=2)
|
| 135 |
try:
|
| 136 |
-
with open(
|
| 137 |
conf = yaml.load(file)
|
| 138 |
|
| 139 |
if conf:
|
| 140 |
-
config, api_keys_db, api_list = update_config(conf)
|
| 141 |
else:
|
| 142 |
logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
| 143 |
config, api_keys_db, api_list = {}, {}, []
|
|
@@ -166,7 +177,7 @@ async def load_config(app=None):
|
|
| 166 |
# 更新配置
|
| 167 |
# logger.info(config_data)
|
| 168 |
if config_data:
|
| 169 |
-
config, api_keys_db, api_list = update_config(config_data)
|
| 170 |
else:
|
| 171 |
logger.error(f"Error fetching or parsing config from {config_url}")
|
| 172 |
config, api_keys_db, api_list = {}, {}, []
|
|
|
|
| 63 |
traceback.print_exc()
|
| 64 |
return []
|
| 65 |
|
| 66 |
+
from ruamel.yaml import YAML, YAMLError
|
| 67 |
+
yaml = YAML()
|
| 68 |
+
yaml.preserve_quotes = True
|
| 69 |
+
yaml.indent(mapping=2, sequence=4, offset=2)
|
| 70 |
+
|
| 71 |
+
API_YAML_PATH = "./api.yaml"
|
| 72 |
+
|
| 73 |
+
def save_api_yaml(config_data):
|
| 74 |
+
with open(API_YAML_PATH, "w", encoding="utf-8") as f:
|
| 75 |
+
yaml.dump(config_data, f)
|
| 76 |
+
|
| 77 |
+
def update_config(config_data, use_config_url=False):
|
| 78 |
for index, provider in enumerate(config_data['providers']):
|
| 79 |
if provider.get('project_id'):
|
| 80 |
provider['base_url'] = 'https://aiplatform.googleapis.com/'
|
|
|
|
| 89 |
provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(provider_api)
|
| 90 |
|
| 91 |
if not provider.get("model"):
|
| 92 |
+
model_list = update_initial_model(provider['base_url'], provider['api'])
|
| 93 |
+
if model_list:
|
| 94 |
+
provider["model"] = model_list
|
| 95 |
+
if not use_config_url:
|
| 96 |
+
save_api_yaml(config_data)
|
| 97 |
|
| 98 |
if provider.get("tools") == None:
|
| 99 |
provider["tools"] = True
|
|
|
|
| 143 |
follow_redirects=True, # 自动跟随重定向
|
| 144 |
)
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
try:
|
| 147 |
+
with open(API_YAML_PATH, 'r', encoding='utf-8') as file:
|
| 148 |
conf = yaml.load(file)
|
| 149 |
|
| 150 |
if conf:
|
| 151 |
+
config, api_keys_db, api_list = update_config(conf, use_config_url=False)
|
| 152 |
else:
|
| 153 |
logger.error("配置文件 'api.yaml' 为空。请检查文件内容。")
|
| 154 |
config, api_keys_db, api_list = {}, {}, []
|
|
|
|
| 177 |
# 更新配置
|
| 178 |
# logger.info(config_data)
|
| 179 |
if config_data:
|
| 180 |
+
config, api_keys_db, api_list = update_config(config_data, use_config_url=True)
|
| 181 |
else:
|
| 182 |
logger.error(f"Error fetching or parsing config from {config_url}")
|
| 183 |
config, api_keys_db, api_list = {}, {}, []
|