Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,45 +7,58 @@ import redis
|
|
| 7 |
import time
|
| 8 |
import os
|
| 9 |
from typing import Dict, List, Optional
|
| 10 |
-
from redis.connection import ConnectionPool # 引入连接
|
| 11 |
|
| 12 |
# ---------------------- 1. 初始化配置 ----------------------
|
| 13 |
app = FastAPI(title="Distributed Agent Controller", version="1.0")
|
| 14 |
|
| 15 |
-
# 跨域配置
|
| 16 |
app.add_middleware(
|
| 17 |
CORSMiddleware,
|
| 18 |
-
allow_origins=["*"],
|
| 19 |
allow_credentials=True,
|
| 20 |
allow_methods=["*"],
|
| 21 |
allow_headers=["*"],
|
| 22 |
)
|
| 23 |
|
| 24 |
-
# ----------------------
|
| 25 |
def create_redis_pool() -> ConnectionPool:
|
| 26 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
pool = ConnectionPool(
|
| 28 |
-
host=
|
| 29 |
-
port=
|
| 30 |
-
password=
|
| 31 |
decode_responses=True,
|
| 32 |
-
|
| 33 |
socket_timeout=5,
|
| 34 |
-
retry_on_timeout=True
|
|
|
|
| 35 |
)
|
|
|
|
|
|
|
| 36 |
with redis.Redis(connection_pool=pool) as client:
|
| 37 |
-
client.ping() #
|
|
|
|
| 38 |
return pool
|
| 39 |
except Exception as e:
|
| 40 |
-
raise RuntimeError(f"Redis
|
| 41 |
|
| 42 |
-
# 初始化连接池
|
| 43 |
redis_pool = create_redis_pool()
|
| 44 |
|
|
|
|
| 45 |
def get_redis_client() -> redis.Redis:
|
| 46 |
-
"""依赖注入:每次请求创建独立客户端,请求结束自动关闭"""
|
| 47 |
with redis.Redis(connection_pool=redis_pool) as client:
|
| 48 |
-
yield client
|
| 49 |
|
| 50 |
# ---------------------- 2. API密钥验证 ----------------------
|
| 51 |
def verify_api_key(api_key: Optional[str] = None):
|
|
@@ -61,7 +74,7 @@ PROCESSING_NODES = [
|
|
| 61 |
"api_key": os.getenv("NODE_API_KEY", "node-key"),
|
| 62 |
"health": "unknown",
|
| 63 |
"last_check": 0,
|
| 64 |
-
"load": 0
|
| 65 |
},
|
| 66 |
{
|
| 67 |
"url": os.getenv("NODE_2_URL", "https://your-username-node2.hf.space"),
|
|
@@ -87,7 +100,6 @@ class SendMessageRequest(BaseModel):
|
|
| 87 |
|
| 88 |
# ---------------------- 5. 工具函数 ----------------------
|
| 89 |
async def check_node_health(node: Dict) -> bool:
|
| 90 |
-
"""检查节点健康状态,5分钟缓存结果"""
|
| 91 |
current_time = time.time()
|
| 92 |
if current_time - node["last_check"] < 300:
|
| 93 |
return node["health"] == "alive"
|
|
@@ -109,7 +121,6 @@ async def check_node_health(node: Dict) -> bool:
|
|
| 109 |
return False
|
| 110 |
|
| 111 |
def get_least_loaded_node() -> Optional[Dict]:
|
| 112 |
-
"""选择负载最低的健康节点"""
|
| 113 |
healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
|
| 114 |
if not healthy_nodes:
|
| 115 |
return None
|
|
@@ -118,33 +129,29 @@ def get_least_loaded_node() -> Optional[Dict]:
|
|
| 118 |
def task_key(user_id: str, task_id: str) -> str:
|
| 119 |
return f"user:{user_id}:task:{task_id}"
|
| 120 |
|
| 121 |
-
# ---------------------- 6. 核心接口
|
| 122 |
@app.post("/task/create")
|
| 123 |
async def create_task(
|
| 124 |
req: CreateTaskRequest,
|
| 125 |
api_key: str = Depends(verify_api_key),
|
| 126 |
-
redis_client: redis.Redis = Depends(get_redis_client)
|
| 127 |
):
|
| 128 |
-
"""创建新任务并生成初始响应"""
|
| 129 |
task_id = f"task_{uuid.uuid4().hex[:8]}"
|
| 130 |
task_data = {
|
| 131 |
"task_id": task_id,
|
| 132 |
"name": req.task_name,
|
| 133 |
-
"created": str(int(time.time())),
|
| 134 |
"updated": str(int(time.time())),
|
| 135 |
-
"history": '[{"role": "user", "content": "%s"}]' % req.initial_prompt.replace('"', '\\"')
|
| 136 |
}
|
| 137 |
|
| 138 |
-
# 保存任务到Redis(使用注入的客户端)
|
| 139 |
redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
|
| 140 |
redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
|
| 141 |
|
| 142 |
-
# 分配节点处理
|
| 143 |
node = get_least_loaded_node()
|
| 144 |
if not node:
|
| 145 |
raise HTTPException(status_code=503, detail="No available nodes")
|
| 146 |
|
| 147 |
-
# 调用节点生成代码
|
| 148 |
try:
|
| 149 |
node["load"] += 1
|
| 150 |
async with httpx.AsyncClient(timeout=60) as client:
|
|
@@ -159,15 +166,14 @@ async def create_task(
|
|
| 159 |
node["load"] -= 1
|
| 160 |
raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
|
| 161 |
|
| 162 |
-
|
| 163 |
-
import json # 建议在文件顶部导入,此处为了突出修复点
|
| 164 |
result = resp.json()
|
| 165 |
-
history = json.loads(task_data["history"])
|
| 166 |
history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
|
| 167 |
redis_client.hset(
|
| 168 |
task_key(req.user_id, task_id),
|
| 169 |
mapping={
|
| 170 |
-
"history": json.dumps(history),
|
| 171 |
"updated": str(int(time.time()))
|
| 172 |
}
|
| 173 |
)
|
|
@@ -184,26 +190,21 @@ async def send_message(
|
|
| 184 |
api_key: str = Depends(verify_api_key),
|
| 185 |
redis_client: redis.Redis = Depends(get_redis_client)
|
| 186 |
):
|
| 187 |
-
"""向现有任务发送消息"""
|
| 188 |
import json
|
| 189 |
key = task_key(req.user_id, req.task_id)
|
| 190 |
if not redis_client.exists(key):
|
| 191 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 192 |
|
| 193 |
-
# 获取历史对话(用json.loads替代eval)
|
| 194 |
task_data = redis_client.hgetall(key)
|
| 195 |
history = json.loads(task_data["history"])
|
| 196 |
history.append({"role": "user", "content": req.message.replace('"', '\\"')})
|
| 197 |
|
| 198 |
-
# 构建上下文
|
| 199 |
context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
|
| 200 |
|
| 201 |
-
# 分配节点处理
|
| 202 |
node = get_least_loaded_node()
|
| 203 |
if not node:
|
| 204 |
raise HTTPException(status_code=503, detail="No available nodes")
|
| 205 |
|
| 206 |
-
# 调用节点生成响应
|
| 207 |
try:
|
| 208 |
node["load"] += 1
|
| 209 |
async with httpx.AsyncClient(timeout=60) as client:
|
|
@@ -222,7 +223,6 @@ async def send_message(
|
|
| 222 |
node["load"] -= 1
|
| 223 |
raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
|
| 224 |
|
| 225 |
-
# 更新历史
|
| 226 |
result = resp.json()
|
| 227 |
history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
|
| 228 |
redis_client.hset(
|
|
@@ -241,7 +241,6 @@ async def list_tasks(
|
|
| 241 |
api_key: str = Depends(verify_api_key),
|
| 242 |
redis_client: redis.Redis = Depends(get_redis_client)
|
| 243 |
):
|
| 244 |
-
"""获取用户所有任务"""
|
| 245 |
import json
|
| 246 |
task_ids = redis_client.smembers(f"user:{user_id}:tasks")
|
| 247 |
tasks = []
|
|
@@ -259,10 +258,8 @@ async def list_tasks(
|
|
| 259 |
|
| 260 |
@app.get("/nodes/health")
|
| 261 |
async def node_health(api_key: str = Depends(verify_api_key)):
|
| 262 |
-
"""查看节点健康状态"""
|
| 263 |
status = []
|
| 264 |
for node in PROCESSING_NODES:
|
| 265 |
-
# 主动触发健康检查(更新最新状态)
|
| 266 |
await check_node_health(node)
|
| 267 |
status.append({
|
| 268 |
"url": node["url"],
|
|
|
|
| 7 |
import time
|
| 8 |
import os
|
| 9 |
from typing import Dict, List, Optional
|
| 10 |
+
from redis.connection import ConnectionPool, SSLConnection # 引入SSL连接类
|
| 11 |
|
| 12 |
# ---------------------- 1. 初始化配置 ----------------------
|
| 13 |
app = FastAPI(title="Distributed Agent Controller", version="1.0")
|
| 14 |
|
| 15 |
+
# 跨域配置
|
| 16 |
app.add_middleware(
|
| 17 |
CORSMiddleware,
|
| 18 |
+
allow_origins=["*"],
|
| 19 |
allow_credentials=True,
|
| 20 |
allow_methods=["*"],
|
| 21 |
allow_headers=["*"],
|
| 22 |
)
|
| 23 |
|
| 24 |
+
# ---------------------- 修复Redis连接池(兼容低版本redis-py) ----------------------
|
| 25 |
def create_redis_pool() -> ConnectionPool:
|
| 26 |
try:
|
| 27 |
+
# 从环境变量获取配置(必须在Space中正确设置)
|
| 28 |
+
redis_host = os.getenv("REDIS_HOST")
|
| 29 |
+
redis_port = int(os.getenv("REDIS_PORT", 6380))
|
| 30 |
+
redis_password = os.getenv("REDIS_PASSWORD")
|
| 31 |
+
|
| 32 |
+
if not all([redis_host, redis_password]):
|
| 33 |
+
raise ValueError("REDIS_HOST和REDIS_PASSWORD必须配置")
|
| 34 |
+
|
| 35 |
+
# 关键修复:使用SSLConnection类,替代ssl=True参数(兼容低版本redis-py)
|
| 36 |
pool = ConnectionPool(
|
| 37 |
+
host=redis_host,
|
| 38 |
+
port=redis_port,
|
| 39 |
+
password=redis_password,
|
| 40 |
decode_responses=True,
|
| 41 |
+
connection_class=SSLConnection, # 强制使用SSL加密连接
|
| 42 |
socket_timeout=5,
|
| 43 |
+
retry_on_timeout=True,
|
| 44 |
+
max_connections=10
|
| 45 |
)
|
| 46 |
+
|
| 47 |
+
# 验证连接
|
| 48 |
with redis.Redis(connection_pool=pool) as client:
|
| 49 |
+
client.ping() # 成功会返回True
|
| 50 |
+
print("Redis连接池初始化成功")
|
| 51 |
return pool
|
| 52 |
except Exception as e:
|
| 53 |
+
raise RuntimeError(f"Redis连接池初始化失败:{str(e)}")
|
| 54 |
|
| 55 |
+
# 初始化连接池
|
| 56 |
redis_pool = create_redis_pool()
|
| 57 |
|
| 58 |
+
# 依赖注入:获取Redis客户端
|
| 59 |
def get_redis_client() -> redis.Redis:
|
|
|
|
| 60 |
with redis.Redis(connection_pool=redis_pool) as client:
|
| 61 |
+
yield client
|
| 62 |
|
| 63 |
# ---------------------- 2. API密钥验证 ----------------------
|
| 64 |
def verify_api_key(api_key: Optional[str] = None):
|
|
|
|
| 74 |
"api_key": os.getenv("NODE_API_KEY", "node-key"),
|
| 75 |
"health": "unknown",
|
| 76 |
"last_check": 0,
|
| 77 |
+
"load": 0
|
| 78 |
},
|
| 79 |
{
|
| 80 |
"url": os.getenv("NODE_2_URL", "https://your-username-node2.hf.space"),
|
|
|
|
| 100 |
|
| 101 |
# ---------------------- 5. 工具函数 ----------------------
|
| 102 |
async def check_node_health(node: Dict) -> bool:
|
|
|
|
| 103 |
current_time = time.time()
|
| 104 |
if current_time - node["last_check"] < 300:
|
| 105 |
return node["health"] == "alive"
|
|
|
|
| 121 |
return False
|
| 122 |
|
| 123 |
def get_least_loaded_node() -> Optional[Dict]:
|
|
|
|
| 124 |
healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
|
| 125 |
if not healthy_nodes:
|
| 126 |
return None
|
|
|
|
| 129 |
def task_key(user_id: str, task_id: str) -> str:
|
| 130 |
return f"user:{user_id}:task:{task_id}"
|
| 131 |
|
| 132 |
+
# ---------------------- 6. 核心接口 ----------------------
|
| 133 |
@app.post("/task/create")
|
| 134 |
async def create_task(
|
| 135 |
req: CreateTaskRequest,
|
| 136 |
api_key: str = Depends(verify_api_key),
|
| 137 |
+
redis_client: redis.Redis = Depends(get_redis_client)
|
| 138 |
):
|
|
|
|
| 139 |
task_id = f"task_{uuid.uuid4().hex[:8]}"
|
| 140 |
task_data = {
|
| 141 |
"task_id": task_id,
|
| 142 |
"name": req.task_name,
|
| 143 |
+
"created": str(int(time.time())),
|
| 144 |
"updated": str(int(time.time())),
|
| 145 |
+
"history": '[{"role": "user", "content": "%s"}]' % req.initial_prompt.replace('"', '\\"')
|
| 146 |
}
|
| 147 |
|
|
|
|
| 148 |
redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
|
| 149 |
redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
|
| 150 |
|
|
|
|
| 151 |
node = get_least_loaded_node()
|
| 152 |
if not node:
|
| 153 |
raise HTTPException(status_code=503, detail="No available nodes")
|
| 154 |
|
|
|
|
| 155 |
try:
|
| 156 |
node["load"] += 1
|
| 157 |
async with httpx.AsyncClient(timeout=60) as client:
|
|
|
|
| 166 |
node["load"] -= 1
|
| 167 |
raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
|
| 168 |
|
| 169 |
+
import json
|
|
|
|
| 170 |
result = resp.json()
|
| 171 |
+
history = json.loads(task_data["history"])
|
| 172 |
history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
|
| 173 |
redis_client.hset(
|
| 174 |
task_key(req.user_id, task_id),
|
| 175 |
mapping={
|
| 176 |
+
"history": json.dumps(history),
|
| 177 |
"updated": str(int(time.time()))
|
| 178 |
}
|
| 179 |
)
|
|
|
|
| 190 |
api_key: str = Depends(verify_api_key),
|
| 191 |
redis_client: redis.Redis = Depends(get_redis_client)
|
| 192 |
):
|
|
|
|
| 193 |
import json
|
| 194 |
key = task_key(req.user_id, req.task_id)
|
| 195 |
if not redis_client.exists(key):
|
| 196 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 197 |
|
|
|
|
| 198 |
task_data = redis_client.hgetall(key)
|
| 199 |
history = json.loads(task_data["history"])
|
| 200 |
history.append({"role": "user", "content": req.message.replace('"', '\\"')})
|
| 201 |
|
|
|
|
| 202 |
context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
|
| 203 |
|
|
|
|
| 204 |
node = get_least_loaded_node()
|
| 205 |
if not node:
|
| 206 |
raise HTTPException(status_code=503, detail="No available nodes")
|
| 207 |
|
|
|
|
| 208 |
try:
|
| 209 |
node["load"] += 1
|
| 210 |
async with httpx.AsyncClient(timeout=60) as client:
|
|
|
|
| 223 |
node["load"] -= 1
|
| 224 |
raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
|
| 225 |
|
|
|
|
| 226 |
result = resp.json()
|
| 227 |
history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
|
| 228 |
redis_client.hset(
|
|
|
|
| 241 |
api_key: str = Depends(verify_api_key),
|
| 242 |
redis_client: redis.Redis = Depends(get_redis_client)
|
| 243 |
):
|
|
|
|
| 244 |
import json
|
| 245 |
task_ids = redis_client.smembers(f"user:{user_id}:tasks")
|
| 246 |
tasks = []
|
|
|
|
| 258 |
|
| 259 |
@app.get("/nodes/health")
|
| 260 |
async def node_health(api_key: str = Depends(verify_api_key)):
|
|
|
|
| 261 |
status = []
|
| 262 |
for node in PROCESSING_NODES:
|
|
|
|
| 263 |
await check_node_health(node)
|
| 264 |
status.append({
|
| 265 |
"url": node["url"],
|