🐛 Bug: Fix the bug where vercel cannot set app.state.config.
Browse files- README.md +2 -0
- README_CN.md +1 -1
- main.py +58 -67
README.md
CHANGED
|
@@ -192,6 +192,8 @@ There are other statistical data that you can query yourself by writing SQL in t
|
|
| 192 |
|
| 193 |
[](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
| 194 |
|
|
|
|
|
|
|
| 195 |
## Docker local deployment
|
| 196 |
|
| 197 |
Start the container
|
|
|
|
| 192 |
|
| 193 |
[](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
| 194 |
|
| 195 |
+
After clicking the one-click deployment button, set the environment variable `CONFIG_URL` to the direct link of the configuration file, and set `DISABLE_DATABASE` to true, then click Create to create the project.
|
| 196 |
+
|
| 197 |
## Docker local deployment
|
| 198 |
|
| 199 |
Start the container
|
README_CN.md
CHANGED
|
@@ -192,7 +192,7 @@ yym68686/uni-api:latest
|
|
| 192 |
|
| 193 |
[](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
| 194 |
|
| 195 |
-
点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链,然后点击 Create 创建项目。
|
| 196 |
|
| 197 |
## Docker 本地部署
|
| 198 |
|
|
|
|
| 192 |
|
| 193 |
[](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Fyym68686%2Funi-api%2Ftree%2Fmain&env=CONFIG_URL,DISABLE_DATABASE&project-name=uni-api-vercel&repository-name=uni-api-vercel)
|
| 194 |
|
| 195 |
+
点击上面的一键部署按钮后,设置环境变量 `CONFIG_URL` 为配置文件的直链, `DISABLE_DATABASE` 为 true,然后点击 Create 创建项目。
|
| 196 |
|
| 197 |
## Docker 本地部署
|
| 198 |
|
main.py
CHANGED
|
@@ -106,17 +106,6 @@ async def lifespan(app: FastAPI):
|
|
| 106 |
verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
|
| 107 |
follow_redirects=True, # 自动跟随重定向
|
| 108 |
)
|
| 109 |
-
# app.state.client = httpx.AsyncClient(timeout=timeout)
|
| 110 |
-
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
|
| 111 |
-
|
| 112 |
-
for item in app.state.api_keys_db:
|
| 113 |
-
if item.get("role") == "admin":
|
| 114 |
-
app.state.admin_api_key = item.get("api")
|
| 115 |
-
if not hasattr(app.state, "admin_api_key"):
|
| 116 |
-
if len(app.state.api_keys_db) >= 1:
|
| 117 |
-
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
|
| 118 |
-
else:
|
| 119 |
-
raise Exception("No admin API key found")
|
| 120 |
|
| 121 |
yield
|
| 122 |
# 关闭时的代码
|
|
@@ -224,6 +213,41 @@ def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> Decimal
|
|
| 224 |
# 返回精确到15位小数的结果
|
| 225 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
class LoggingStreamingResponse(Response):
|
| 228 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|
| 229 |
super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
|
|
@@ -263,31 +287,14 @@ class LoggingStreamingResponse(Response):
|
|
| 263 |
|
| 264 |
process_time = time() - self.current_info["start_time"]
|
| 265 |
self.current_info["process_time"] = process_time
|
| 266 |
-
await
|
| 267 |
-
|
| 268 |
-
async def update_stats(self):
|
| 269 |
-
# 这里添加更新数据库的逻辑
|
| 270 |
-
# print("current_info2")
|
| 271 |
-
if DISABLE_DATABASE:
|
| 272 |
-
return
|
| 273 |
-
async with async_session() as session:
|
| 274 |
-
async with session.begin():
|
| 275 |
-
try:
|
| 276 |
-
columns = [column.key for column in RequestStat.__table__.columns]
|
| 277 |
-
filtered_info = {k: v for k, v in self.current_info.items() if k in columns}
|
| 278 |
-
new_request_stat = RequestStat(**filtered_info)
|
| 279 |
-
session.add(new_request_stat)
|
| 280 |
-
await session.commit()
|
| 281 |
-
except Exception as e:
|
| 282 |
-
await session.rollback()
|
| 283 |
-
logger.error(f"Error updating stats: {str(e)}")
|
| 284 |
|
| 285 |
async def _logging_iterator(self):
|
| 286 |
try:
|
| 287 |
async for chunk in self.body_iterator:
|
| 288 |
if isinstance(chunk, str):
|
| 289 |
chunk = chunk.encode('utf-8')
|
| 290 |
-
line = chunk.decode()
|
| 291 |
if is_debug:
|
| 292 |
logger.info(f"{line}")
|
| 293 |
if line.startswith("data:"):
|
|
@@ -435,41 +442,6 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
| 435 |
# print("current_request_info", current_request_info)
|
| 436 |
request_info.reset(current_request_info)
|
| 437 |
|
| 438 |
-
async def update_stats(self, current_info):
|
| 439 |
-
if DISABLE_DATABASE:
|
| 440 |
-
return
|
| 441 |
-
# 这里添加更新数据库的逻辑
|
| 442 |
-
async with async_session() as session:
|
| 443 |
-
async with session.begin():
|
| 444 |
-
try:
|
| 445 |
-
columns = [column.key for column in RequestStat.__table__.columns]
|
| 446 |
-
filtered_info = {k: v for k, v in current_info.items() if k in columns}
|
| 447 |
-
new_request_stat = RequestStat(**filtered_info)
|
| 448 |
-
session.add(new_request_stat)
|
| 449 |
-
await session.commit()
|
| 450 |
-
except Exception as e:
|
| 451 |
-
await session.rollback()
|
| 452 |
-
logger.error(f"Error updating stats: {str(e)}")
|
| 453 |
-
|
| 454 |
-
async def update_channel_stats(self, request_id, provider, model, api_key, success):
|
| 455 |
-
if DISABLE_DATABASE:
|
| 456 |
-
return
|
| 457 |
-
async with async_session() as session:
|
| 458 |
-
async with session.begin():
|
| 459 |
-
try:
|
| 460 |
-
channel_stat = ChannelStat(
|
| 461 |
-
request_id=request_id,
|
| 462 |
-
provider=provider,
|
| 463 |
-
model=model,
|
| 464 |
-
api_key=api_key,
|
| 465 |
-
success=success,
|
| 466 |
-
)
|
| 467 |
-
session.add(channel_stat)
|
| 468 |
-
await session.commit()
|
| 469 |
-
except Exception as e:
|
| 470 |
-
await session.rollback()
|
| 471 |
-
logger.error(f"Error updating channel stats: {str(e)}")
|
| 472 |
-
|
| 473 |
async def moderate_content(self, content, token):
|
| 474 |
moderation_request = ModerationRequest(input=content)
|
| 475 |
|
|
@@ -500,6 +472,23 @@ app.add_middleware(
|
|
| 500 |
|
| 501 |
app.add_middleware(StatsMiddleware)
|
| 502 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
# 在 process_request 函数中更新成功和失败计数
|
| 504 |
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
| 505 |
url = provider['base_url']
|
|
@@ -581,14 +570,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
| 581 |
# response = JSONResponse(first_element)
|
| 582 |
|
| 583 |
# 更新成功计数和首次响应时间
|
| 584 |
-
await
|
|
|
|
| 585 |
current_info["first_response_time"] = first_response_time
|
| 586 |
current_info["success"] = True
|
| 587 |
current_info["provider"] = provider['provider']
|
| 588 |
|
| 589 |
return response
|
| 590 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
| 591 |
-
await
|
|
|
|
| 592 |
|
| 593 |
raise e
|
| 594 |
|
|
|
|
| 106 |
verify=True, # 保持 SSL 验证(如需禁用,设为 False,但不建议)
|
| 107 |
follow_redirects=True, # 自动跟随重定向
|
| 108 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
yield
|
| 111 |
# 关闭时的代码
|
|
|
|
| 213 |
# 返回精确到15位小数的结果
|
| 214 |
return total_cost.quantize(Decimal('0.000000000000001'))
|
| 215 |
|
| 216 |
+
async def update_stats(current_info):
|
| 217 |
+
if DISABLE_DATABASE:
|
| 218 |
+
return
|
| 219 |
+
# 这里添加更新数据库的逻辑
|
| 220 |
+
async with async_session() as session:
|
| 221 |
+
async with session.begin():
|
| 222 |
+
try:
|
| 223 |
+
columns = [column.key for column in RequestStat.__table__.columns]
|
| 224 |
+
filtered_info = {k: v for k, v in current_info.items() if k in columns}
|
| 225 |
+
new_request_stat = RequestStat(**filtered_info)
|
| 226 |
+
session.add(new_request_stat)
|
| 227 |
+
await session.commit()
|
| 228 |
+
except Exception as e:
|
| 229 |
+
await session.rollback()
|
| 230 |
+
logger.error(f"Error updating stats: {str(e)}")
|
| 231 |
+
|
| 232 |
+
async def update_channel_stats(request_id, provider, model, api_key, success):
|
| 233 |
+
if DISABLE_DATABASE:
|
| 234 |
+
return
|
| 235 |
+
async with async_session() as session:
|
| 236 |
+
async with session.begin():
|
| 237 |
+
try:
|
| 238 |
+
channel_stat = ChannelStat(
|
| 239 |
+
request_id=request_id,
|
| 240 |
+
provider=provider,
|
| 241 |
+
model=model,
|
| 242 |
+
api_key=api_key,
|
| 243 |
+
success=success,
|
| 244 |
+
)
|
| 245 |
+
session.add(channel_stat)
|
| 246 |
+
await session.commit()
|
| 247 |
+
except Exception as e:
|
| 248 |
+
await session.rollback()
|
| 249 |
+
logger.error(f"Error updating channel stats: {str(e)}")
|
| 250 |
+
|
| 251 |
class LoggingStreamingResponse(Response):
|
| 252 |
def __init__(self, content, status_code=200, headers=None, media_type=None, current_info=None):
|
| 253 |
super().__init__(content=None, status_code=status_code, headers=headers, media_type=media_type)
|
|
|
|
| 287 |
|
| 288 |
process_time = time() - self.current_info["start_time"]
|
| 289 |
self.current_info["process_time"] = process_time
|
| 290 |
+
await update_stats(self.current_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
async def _logging_iterator(self):
|
| 293 |
try:
|
| 294 |
async for chunk in self.body_iterator:
|
| 295 |
if isinstance(chunk, str):
|
| 296 |
chunk = chunk.encode('utf-8')
|
| 297 |
+
line = chunk.decode('utf-8')
|
| 298 |
if is_debug:
|
| 299 |
logger.info(f"{line}")
|
| 300 |
if line.startswith("data:"):
|
|
|
|
| 442 |
# print("current_request_info", current_request_info)
|
| 443 |
request_info.reset(current_request_info)
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
async def moderate_content(self, content, token):
|
| 446 |
moderation_request = ModerationRequest(input=content)
|
| 447 |
|
|
|
|
| 472 |
|
| 473 |
app.add_middleware(StatsMiddleware)
|
| 474 |
|
| 475 |
+
@app.middleware("http")
|
| 476 |
+
async def ensure_config(request: Request, call_next):
|
| 477 |
+
if not hasattr(app.state, 'config'):
|
| 478 |
+
logger.warning("Config not found, attempting to reload")
|
| 479 |
+
app.state.config, app.state.api_keys_db, app.state.api_list = await load_config(app)
|
| 480 |
+
|
| 481 |
+
for item in app.state.api_keys_db:
|
| 482 |
+
if item.get("role") == "admin":
|
| 483 |
+
app.state.admin_api_key = item.get("api")
|
| 484 |
+
if not hasattr(app.state, "admin_api_key"):
|
| 485 |
+
if len(app.state.api_keys_db) >= 1:
|
| 486 |
+
app.state.admin_api_key = app.state.api_keys_db[0].get("api")
|
| 487 |
+
else:
|
| 488 |
+
raise Exception("No admin API key found")
|
| 489 |
+
|
| 490 |
+
return await call_next(request)
|
| 491 |
+
|
| 492 |
# 在 process_request 函数中更新成功和失败计数
|
| 493 |
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
|
| 494 |
url = provider['base_url']
|
|
|
|
| 570 |
# response = JSONResponse(first_element)
|
| 571 |
|
| 572 |
# 更新成功计数和首次响应时间
|
| 573 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
|
| 574 |
+
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
|
| 575 |
current_info["first_response_time"] = first_response_time
|
| 576 |
current_info["success"] = True
|
| 577 |
current_info["provider"] = provider['provider']
|
| 578 |
|
| 579 |
return response
|
| 580 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError) as e:
|
| 581 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
|
| 582 |
+
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
|
| 583 |
|
| 584 |
raise e
|
| 585 |
|