Spaces:
Sleeping
Sleeping
File size: 9,124 Bytes
4c2a557 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 | from typing import Optional, Callable, Any, Awaitable
from pydantic import Field, BaseModel
import requests
import time
from open_webui.utils.misc import get_last_assistant_message
import json
import os
class Filter:
    """OpenWebUI filter that reports each chat turn to an external usage/billing API.

    ``inlet`` sends the incoming request to ``<API_ENDPOINT>/api/v1/inlet`` and
    raises when the API reports a failure or a non-positive balance.
    ``outlet`` sends the finished turn to ``<API_ENDPOINT>/api/v1/outlet``,
    receives token counts, cost and the new balance, and writes those
    statistics to ``<_RECORD_DIR>/<message_id>.json``.
    """

    class Valves(BaseModel):
        API_ENDPOINT: str = Field(
            default="", description="The base URL for the API endpoint."
        )
        API_KEY: str = Field(default="", description="API key for authentication.")
        priority: int = Field(
            default=5, description="Priority level for the filter operations."
        )
        # New, optional setting: without a timeout a stalled API server would hang
        # the chat request indefinitely.
        REQUEST_TIMEOUT: float = Field(
            default=30.0,
            description="Timeout in seconds for each API request (0 disables it).",
        )

    class _ReportedError(Exception):
        """Error whose message is already the user-facing text.

        inlet/outlet re-raise it as-is instead of wrapping it with the generic
        "error while processing the request" prefix a second time.
        """

    # Directory the per-message statistics files are written to.
    _RECORD_DIR = "/app/backend/data/record"

    def __init__(self):
        self.type = "filter"
        self.name = "OpenWebUI Monitor"
        self.valves = self.Valves()
        # Set by inlet when the API reports a non-positive balance; while True,
        # outlet returns immediately without reporting.
        self.outage = False
        # time.time() taken at the start of inlet; outlet uses it for elapsed time.
        self.start_time = None
        # Request body sent by the last inlet call; outlet reuses its messages.
        self.inlet_temp = None
        # NOTE(review): these three attributes live on the filter instance. If the
        # host shares one instance across concurrent chats, one request may read
        # another's values — confirm how filters are instantiated.

    def _api_url(self, path: str) -> str:
        """Return ``<API_ENDPOINT>/api/v1/<path>``, tolerating a trailing '/'."""
        return f"{self.valves.API_ENDPOINT.rstrip('/')}/api/v1/{path}"

    def _post(self, path: str, request_data: dict) -> "requests.Response":
        """POST ``request_data`` as JSON to the API with bearer auth and a timeout."""
        timeout = self.valves.REQUEST_TIMEOUT
        return requests.post(
            self._api_url(path),
            headers={"Authorization": f"Bearer {self.valves.API_KEY}"},
            json=request_data,
            # A non-positive value means "no timeout" (the previous behaviour).
            timeout=timeout if timeout and timeout > 0 else None,
        )

    @staticmethod
    def _is_unauthorized(error: Exception) -> bool:
        """Return True if ``error`` is an HTTP error carrying a 401 response."""
        response = getattr(error, "response", None)
        return (
            isinstance(error, requests.exceptions.HTTPError)
            and response is not None
            and response.status_code == 401
        )

    @staticmethod
    async def _emit_status(emitter, description: str) -> None:
        """Send a final (``done``) status event through ``emitter`` if one is set."""
        if emitter:
            await emitter(
                {
                    "type": "status",
                    "data": {
                        "description": description,
                        "done": True,
                    },
                }
            )

    def _prepare_request_body(self, body: dict) -> dict:
        """Return a shallow copy of ``body`` that can be JSON-serialized.

        If ``body["metadata"]["model"]`` has ``model_dump`` (e.g. a pydantic
        model) it is replaced by its dict form inside a *copied* ``metadata``
        dict, so the caller's ``body`` is never modified.
        """
        body_copy = dict(body)
        metadata = body_copy.get("metadata")
        if metadata and "model" in metadata and hasattr(metadata["model"], "model_dump"):
            metadata = dict(metadata)
            metadata["model"] = metadata["model"].model_dump()
            body_copy["metadata"] = metadata
        return body_copy

    def _prepare_user_dict(self, __user__: Optional[dict]) -> dict:
        """Return a serializable copy of ``__user__`` (``None`` becomes ``{}``).

        A ``valves`` entry exposing ``model_dump`` is replaced by its dict form;
        the caller's object is not modified.
        """
        user_dict = dict(__user__ or {})
        valves = user_dict.get("valves")
        if hasattr(valves, "model_dump"):
            user_dict["valves"] = valves.model_dump()
        return user_dict

    def _modify_outlet_body(self, body: dict) -> dict:
        """Return a copy of ``body`` whose history comes from the last inlet call.

        Unless the last message has an ``info`` key, every message except the
        last is replaced by the messages recorded in ``self.inlet_temp``. A new
        list is built, so ``body``'s own message list is left untouched.
        Missing/empty messages or no recorded inlet body leave ``body`` as is.
        """
        body_modify = dict(body)
        messages = body_modify.get("messages") or []
        if not messages or self.inlet_temp is None:
            return body_modify
        last_message = messages[-1]
        if "info" not in last_message:
            inlet_messages = self.inlet_temp.get("messages") or []
            body_modify["messages"] = [*inlet_messages, last_message]
        return body_modify

    @staticmethod
    def _record_file_name(message_id: Any) -> Optional[str]:
        """Return ``"<message_id>.json"``, or None if the ID contains a path part.

        The ID is read from the request body, so anything that is not a bare
        file name (e.g. ``"../x"``) is refused rather than written outside
        ``_RECORD_DIR``.
        """
        name = str(message_id)
        if os.path.basename(name) != name:
            return None
        return f"{name}.json"

    def _build_stats(self, result: dict) -> dict:
        """Collect token/cost statistics from an outlet API response.

        Adds ``elapsed_time`` (seconds since inlet) and ``tokens_per_sec`` when
        ``self.start_time`` is set. Raises KeyError if a required field is missing.
        """
        stats_data = {
            "input_tokens": result["inputTokens"],
            "output_tokens": result["outputTokens"],
            "total_cost": result["totalCost"],
            "new_balance": result["newBalance"],
        }
        if self.start_time:
            elapsed_time = time.time() - self.start_time
            stats_data["elapsed_time"] = elapsed_time
            # Guard against division by zero on a zero-length interval.
            stats_data["tokens_per_sec"] = (
                stats_data["output_tokens"] / elapsed_time if elapsed_time > 0 else 0
            )
        return stats_data

    def _write_stats(self, file_name: str, stats_data: dict) -> None:
        """Write ``stats_data`` as indented JSON to ``_RECORD_DIR/file_name``."""
        os.makedirs(self._RECORD_DIR, exist_ok=True)
        file_path = os.path.join(self._RECORD_DIR, file_name)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(stats_data, f, indent=4)

    def inlet(
        self, body: dict, user: Optional[dict] = None, __user__: Optional[dict] = None
    ) -> dict:
        """Report an incoming request to the API and check the user's balance.

        Args:
            body: request body; returned unchanged.
            user: accepted but not used here.
            __user__: user object, sent to the API in serializable form.

        Returns:
            ``body``, on success or when the API answers HTTP 401.

        Raises:
            Exception: on network errors, an unsuccessful API response, a
                non-positive balance, or any other processing error.
        """
        self.start_time = time.time()
        try:
            user_dict = self._prepare_user_dict(__user__)
            body_dict = self._prepare_request_body(body)
            self.inlet_temp = body_dict
            response = self._post("inlet", {"user": user_dict, "body": body_dict})
            if response.status_code == 401:
                return body
            response.raise_for_status()
            response_data = response.json()
            if not response_data.get("success"):
                error_msg = response_data.get("error", "未知错误")
                error_type = response_data.get("error_type", "UNKNOWN_ERROR")
                raise self._ReportedError(f"请求失败: [{error_type}] {error_msg}")
            # A missing/null balance counts as 0, matching the original default.
            balance = float(response_data.get("balance") or 0)
            self.outage = balance <= 0
            if self.outage:
                raise self._ReportedError(f"余额不足: 当前余额 `{balance:.4f}`")
            return body
        except requests.exceptions.RequestException as e:
            if self._is_unauthorized(e):
                return body
            raise Exception(f"网络请求失败: {str(e)}") from e
        except self._ReportedError:
            raise
        except Exception as e:
            raise Exception(f"处理请求时发生错误: {str(e)}") from e

    async def outlet(
        self,
        body: dict,
        user: Optional[dict] = None,
        __user__: Optional[dict] = None,
        __event_emitter__: Optional[Callable[[Any], Awaitable[None]]] = None,
    ) -> dict:
        """Report a finished turn to the API and store its usage statistics.

        Skipped entirely while ``self.outage`` is set. On success, the
        statistics are written to ``<_RECORD_DIR>/<id of last message>.json``.
        If no usable message ID is found, a status event is emitted instead.

        Args:
            body: response body; returned unchanged.
            user: accepted but not used here.
            __user__: user object, sent to the API in serializable form.
            __event_emitter__: optional async callback for status events.

        Returns:
            ``body``.

        Raises:
            Exception: on network errors, an unsuccessful API response, or any
                other processing error (after emitting an error status).
        """
        import asyncio
        import logging

        if self.outage:
            return body
        try:
            user_dict = self._prepare_user_dict(__user__)
            body_dict = self._prepare_request_body(body)
            body_modify = self._modify_outlet_body(body_dict)
            request_data = {"user": user_dict, "body": body_modify}
            # requests is blocking; run it off the event loop.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None, self._post, "outlet", request_data
            )
            if response.status_code == 401:
                await self._emit_status(__event_emitter__, "API密钥验证失败")
                return body
            response.raise_for_status()
            result = response.json()
            if not result.get("success"):
                error_msg = result.get("error", "未知错误")
                error_type = result.get("error_type", "UNKNOWN_ERROR")
                raise self._ReportedError(f"请求失败: [{error_type}] {error_msg}")
            stats_data = self._build_stats(result)

            # Debug-level only: these contain user details and the conversation.
            logger = logging.getLogger(__name__)
            logger.debug("outlet user_dict: %s", user_dict)
            logger.debug("outlet body: %s", body_modify)

            # The stats file is named after the ID of the last message.
            messages = body.get("messages", [])
            message_id = messages[-1].get("id") if messages else None
            file_name = self._record_file_name(message_id) if message_id else None
            if file_name:
                self._write_stats(file_name, stats_data)
            else:
                await self._emit_status(__event_emitter__, "无法获取消息ID")
            return body
        except requests.exceptions.RequestException as e:
            if self._is_unauthorized(e):
                await self._emit_status(__event_emitter__, "API密钥验证失败")
                return body
            await self._emit_status(__event_emitter__, f"网络请求失败: {str(e)}")
            raise Exception(f"网络请求失败: {str(e)}") from e
        except self._ReportedError as e:
            await self._emit_status(__event_emitter__, f"错误: {str(e)}")
            raise
        except Exception as e:
            await self._emit_status(__event_emitter__, f"错误: {str(e)}")
            raise Exception(f"处理请求时发生错误: {str(e)}") from e
|