Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ import redis
|
|
| 7 |
import time
|
| 8 |
import os
|
| 9 |
from typing import Dict, List, Optional
|
|
|
|
| 10 |
|
| 11 |
# ---------------------- 1. 初始化配置 ----------------------
|
| 12 |
app = FastAPI(title="Distributed Agent Controller", version="1.0")
|
|
@@ -14,36 +15,48 @@ app = FastAPI(title="Distributed Agent Controller", version="1.0")
|
|
| 14 |
# 跨域配置(适配Hugging Face Space环境)
|
| 15 |
app.add_middleware(
|
| 16 |
CORSMiddleware,
|
| 17 |
-
allow_origins=["*"], #
|
| 18 |
allow_credentials=True,
|
| 19 |
allow_methods=["*"],
|
| 20 |
allow_headers=["*"],
|
| 21 |
)
|
| 22 |
|
| 23 |
-
# Redis
|
| 24 |
-
def
|
|
|
|
| 25 |
try:
|
| 26 |
-
|
| 27 |
host=os.getenv("REDIS_HOST"),
|
| 28 |
port=int(os.getenv("REDIS_PORT", 6379)),
|
| 29 |
password=os.getenv("REDIS_PASSWORD"),
|
| 30 |
decode_responses=True,
|
| 31 |
socket_timeout=5,
|
| 32 |
-
retry_on_timeout=True
|
|
|
|
| 33 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
except Exception as e:
|
| 35 |
-
raise RuntimeError(f"Redis connection failed: {str(e)}")
|
| 36 |
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def verify_api_key(api_key: Optional[str] = None):
    """Validate the caller-supplied API key against CONTROLLER_API_KEY.

    Args:
        api_key: Key supplied by the caller; ``None`` when it is absent.

    Returns:
        The supplied key, unchanged, when it matches the configured key.

    Raises:
        HTTPException: 401 when the key is missing or does not match.
    """
    import hmac  # local import keeps this block self-contained

    # NOTE(review): with the "default-key" fallback, an unset environment
    # variable leaves the API guarded by a publicly known key; consider
    # failing closed instead.
    valid_key = os.getenv("CONTROLLER_API_KEY", "default-key")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    if api_key is None or not hmac.compare_digest(
        api_key.encode("utf-8"), valid_key.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
|
| 45 |
|
| 46 |
-
#
|
| 47 |
PROCESSING_NODES = [
|
| 48 |
{
|
| 49 |
"url": os.getenv("NODE_1_URL", "https://your-username-node1.hf.space"),
|
|
@@ -61,7 +74,7 @@ PROCESSING_NODES = [
|
|
| 61 |
}
|
| 62 |
]
|
| 63 |
|
| 64 |
-
# ----------------------
|
| 65 |
class CreateTaskRequest(BaseModel):
|
| 66 |
user_id: str
|
| 67 |
task_name: str
|
|
@@ -74,7 +87,7 @@ class SendMessageRequest(BaseModel):
|
|
| 74 |
max_new_tokens: int = 512
|
| 75 |
temperature: float = 0.7
|
| 76 |
|
| 77 |
-
# ----------------------
|
| 78 |
async def check_node_health(node: Dict) -> bool:
|
| 79 |
"""检查节点健康状态,5分钟缓存结果"""
|
| 80 |
current_time = time.time()
|
|
@@ -93,35 +106,38 @@ async def check_node_health(node: Dict) -> bool:
|
|
| 93 |
return True
|
| 94 |
node["health"] = "dead"
|
| 95 |
return False
|
| 96 |
-
except:
|
| 97 |
node["health"] = "dead"
|
| 98 |
return False
|
| 99 |
|
| 100 |
def get_least_loaded_node() -> Optional[Dict]:
|
| 101 |
-
"""
|
| 102 |
healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
|
| 103 |
if not healthy_nodes:
|
| 104 |
return None
|
| 105 |
-
# 按当前负载排序,选择负载最低的
|
| 106 |
return min(healthy_nodes, key=lambda x: x["load"])
|
| 107 |
|
| 108 |
def task_key(user_id: str, task_id: str) -> str:
    """Compose the storage key for one task belonging to one user."""
    return "user:{}:task:{}".format(user_id, task_id)
|
| 110 |
|
| 111 |
-
# ----------------------
|
| 112 |
@app.post("/task/create")
|
| 113 |
-
async def create_task(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
"""创建新任务并生成初始响应"""
|
| 115 |
task_id = f"task_{uuid.uuid4().hex[:8]}"
|
| 116 |
task_data = {
|
| 117 |
"task_id": task_id,
|
| 118 |
"name": req.task_name,
|
| 119 |
-
"created": int(time.time()),
|
| 120 |
-
"updated": int(time.time()),
|
| 121 |
-
"history": [{"role": "user", "content": req.initial_prompt
|
| 122 |
}
|
| 123 |
|
| 124 |
-
# 保存任务到Redis
|
| 125 |
redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
|
| 126 |
redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
|
| 127 |
|
|
@@ -132,26 +148,30 @@ async def create_task(req: CreateTaskRequest, api_key: str = Depends(verify_api_
|
|
| 132 |
|
| 133 |
# 调用节点生成代码
|
| 134 |
try:
|
| 135 |
-
node["load"] += 1
|
| 136 |
async with httpx.AsyncClient(timeout=60) as client:
|
| 137 |
resp = await client.post(
|
| 138 |
f"{node['url']}/generate/code",
|
| 139 |
json={"prompt": req.initial_prompt, "max_new_tokens": 512},
|
| 140 |
params={"api_key": node["api_key"]}
|
| 141 |
)
|
| 142 |
-
node["load"] -= 1
|
| 143 |
resp.raise_for_status()
|
| 144 |
except Exception as e:
|
| 145 |
node["load"] -= 1
|
| 146 |
raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
|
| 147 |
|
| 148 |
-
#
|
|
|
|
| 149 |
result = resp.json()
|
| 150 |
-
history =
|
| 151 |
-
history.append({"role": "assistant", "content": result["code"]})
|
| 152 |
redis_client.hset(
|
| 153 |
task_key(req.user_id, task_id),
|
| 154 |
-
mapping={
|
|
|
|
|
|
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
return {
|
|
@@ -161,16 +181,21 @@ async def create_task(req: CreateTaskRequest, api_key: str = Depends(verify_api_
|
|
| 161 |
}
|
| 162 |
|
| 163 |
@app.post("/task/message")
|
| 164 |
-
async def send_message(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
"""向现有任务发送消息"""
|
|
|
|
| 166 |
key = task_key(req.user_id, req.task_id)
|
| 167 |
if not redis_client.exists(key):
|
| 168 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 169 |
|
| 170 |
-
#
|
| 171 |
task_data = redis_client.hgetall(key)
|
| 172 |
-
history =
|
| 173 |
-
history.append({"role": "user", "content": req.message})
|
| 174 |
|
| 175 |
# 构建上下文
|
| 176 |
context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
|
|
@@ -201,17 +226,25 @@ async def send_message(req: SendMessageRequest, api_key: str = Depends(verify_ap
|
|
| 201 |
|
| 202 |
# 更新历史
|
| 203 |
result = resp.json()
|
| 204 |
-
history.append({"role": "assistant", "content": result["code"]})
|
| 205 |
redis_client.hset(
|
| 206 |
key,
|
| 207 |
-
mapping={
|
|
|
|
|
|
|
|
|
|
| 208 |
)
|
| 209 |
|
| 210 |
return {"status": "success", "response": result["code"]}
|
| 211 |
|
| 212 |
@app.get("/user/tasks/{user_id}")
|
| 213 |
-
async def list_tasks(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
"""获取用户所有任务"""
|
|
|
|
| 215 |
task_ids = redis_client.smembers(f"user:{user_id}:tasks")
|
| 216 |
tasks = []
|
| 217 |
for tid in task_ids:
|
|
@@ -222,7 +255,7 @@ async def list_tasks(user_id: str, api_key: str = Depends(verify_api_key)):
|
|
| 222 |
"name": data["name"],
|
| 223 |
"created": time.ctime(int(data["created"])),
|
| 224 |
"updated": time.ctime(int(data["updated"])),
|
| 225 |
-
"history_length": len(
|
| 226 |
})
|
| 227 |
return {"status": "success", "tasks": tasks}
|
| 228 |
|
|
@@ -231,11 +264,12 @@ async def node_health(api_key: str = Depends(verify_api_key)):
|
|
| 231 |
"""查看节点健康状态"""
|
| 232 |
status = []
|
| 233 |
for node in PROCESSING_NODES:
|
|
|
|
|
|
|
| 234 |
status.append({
|
| 235 |
"url": node["url"],
|
| 236 |
"health": node["health"],
|
| 237 |
"load": node["load"],
|
| 238 |
"last_check": time.ctime(node["last_check"]) if node["last_check"] else "Never"
|
| 239 |
})
|
| 240 |
-
return {"status": "success", "nodes": status}
|
| 241 |
-
|
|
|
|
| 7 |
import time
|
| 8 |
import os
|
| 9 |
from typing import Dict, List, Optional
|
| 10 |
+
from redis.connection import ConnectionPool # 引入连接池
|
| 11 |
|
| 12 |
# ---------------------- 1. 初始化配置 ----------------------
|
| 13 |
app = FastAPI(title="Distributed Agent Controller", version="1.0")
|
|
|
|
| 15 |
# 跨域配置(适配Hugging Face Space环境)
|
| 16 |
app.add_middleware(
|
| 17 |
CORSMiddleware,
|
| 18 |
+
allow_origins=["*"], # 生产环境替换为具体域名
|
| 19 |
allow_credentials=True,
|
| 20 |
allow_methods=["*"],
|
| 21 |
allow_headers=["*"],
|
| 22 |
)
|
| 23 |
|
| 24 |
+
# ---------------------- 关键修复:Redis连接池 + 依赖注入 ----------------------
|
| 25 |
+
def create_redis_pool() -> ConnectionPool:
    """Build the shared Redis connection pool and verify that it is usable.

    Connection settings are read from the REDIS_HOST, REDIS_PORT (default
    6379) and REDIS_PASSWORD environment variables.

    Returns:
        The configured ConnectionPool, after a PING issued through it
        succeeded.

    Raises:
        RuntimeError: If building the pool or the PING fails; the original
            exception is attached as ``__cause__``.
    """
    try:
        pool = ConnectionPool(
            host=os.getenv("REDIS_HOST"),
            port=int(os.getenv("REDIS_PORT", 6379)),
            password=os.getenv("REDIS_PASSWORD"),
            decode_responses=True,
            socket_timeout=5,
            retry_on_timeout=True,
            max_connections=10,  # cap the number of pooled connections
        )
        # Fail fast: round-trip a PING through a short-lived client.
        with redis.Redis(connection_pool=pool) as client:
            client.ping()
        return pool
    except Exception as e:
        # Chain explicitly so the underlying error is reported as the cause.
        raise RuntimeError(f"Redis connection pool failed: {e}") from e
|
| 43 |
|
| 44 |
+
# 初始化连接池(全局仅创建一次)
|
| 45 |
+
redis_pool = create_redis_pool()
|
| 46 |
|
| 47 |
+
def get_redis_client() -> redis.Redis:
    """Yield a Redis client backed by the module-level ``redis_pool``.

    Written as a generator so it can be used with ``Depends``: the client's
    context manager is exited when the generator is resumed after the yield.
    """
    client = redis.Redis(connection_pool=redis_pool)
    with client as handle:
        yield handle
|
| 51 |
+
|
| 52 |
+
# ---------------------- 2. API密钥验证 ----------------------
|
| 53 |
def verify_api_key(api_key: Optional[str] = None):
    """Validate the caller-supplied API key against CONTROLLER_API_KEY.

    Args:
        api_key: Key supplied by the caller; ``None`` when it is absent.

    Returns:
        The supplied key, unchanged, when it matches the configured key.

    Raises:
        HTTPException: 401 when the key is missing or does not match.
    """
    import hmac  # local import keeps this block self-contained

    # NOTE(review): with the "default-key" fallback, an unset environment
    # variable leaves the API guarded by a publicly known key; consider
    # failing closed instead.
    valid_key = os.getenv("CONTROLLER_API_KEY", "default-key")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    if api_key is None or not hmac.compare_digest(
        api_key.encode("utf-8"), valid_key.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
|
| 58 |
|
| 59 |
+
# ---------------------- 3. 处理节点配置 ----------------------
|
| 60 |
PROCESSING_NODES = [
|
| 61 |
{
|
| 62 |
"url": os.getenv("NODE_1_URL", "https://your-username-node1.hf.space"),
|
|
|
|
| 74 |
}
|
| 75 |
]
|
| 76 |
|
| 77 |
+
# ---------------------- 4. 数据模型 ----------------------
|
| 78 |
class CreateTaskRequest(BaseModel):
|
| 79 |
user_id: str
|
| 80 |
task_name: str
|
|
|
|
| 87 |
max_new_tokens: int = 512
|
| 88 |
temperature: float = 0.7
|
| 89 |
|
| 90 |
+
# ---------------------- 5. 工具函数 ----------------------
|
| 91 |
async def check_node_health(node: Dict) -> bool:
|
| 92 |
"""检查节点健康状态,5分钟缓存结果"""
|
| 93 |
current_time = time.time()
|
|
|
|
| 106 |
return True
|
| 107 |
node["health"] = "dead"
|
| 108 |
return False
|
| 109 |
+
except Exception as e:
|
| 110 |
node["health"] = "dead"
|
| 111 |
return False
|
| 112 |
|
| 113 |
def get_least_loaded_node() -> Optional[Dict]:
    """Pick the least-loaded processing node that is not marked dead.

    Returns:
        The entry of PROCESSING_NODES with the smallest ``"load"`` among
        those whose ``"health"`` flag is not ``"dead"``, or ``None`` when no
        such node exists.
    """
    # check_node_health is a coroutine function: calling it without ``await``
    # only creates a coroutine object, which is always truthy (and triggers a
    # "never awaited" warning), so it cannot filter nodes from this
    # synchronous function. Use the health flag it records on each node.
    # NOTE(review): the flag is only refreshed where check_node_health is
    # actually awaited; confirm that happens often enough for routing.
    return min(
        (node for node in PROCESSING_NODES if node.get("health") != "dead"),
        key=lambda node: node["load"],
        default=None,
    )
|
| 119 |
|
| 120 |
def task_key(user_id: str, task_id: str) -> str:
    """Compose the storage key for one task belonging to one user."""
    return "user:{}:task:{}".format(user_id, task_id)
|
| 122 |
|
| 123 |
+
# ---------------------- 6. 核心接口(修复Redis调用逻辑) ----------------------
|
| 124 |
@app.post("/task/create")
|
| 125 |
+
async def create_task(
|
| 126 |
+
req: CreateTaskRequest,
|
| 127 |
+
api_key: str = Depends(verify_api_key),
|
| 128 |
+
redis_client: redis.Redis = Depends(get_redis_client) # 依赖注入Redis客户端
|
| 129 |
+
):
|
| 130 |
"""创建新任务并生成初始响应"""
|
| 131 |
task_id = f"task_{uuid.uuid4().hex[:8]}"
|
| 132 |
task_data = {
|
| 133 |
"task_id": task_id,
|
| 134 |
"name": req.task_name,
|
| 135 |
+
"created": str(int(time.time())), # 存储为字符串,避免Redis哈希类型转换问题
|
| 136 |
+
"updated": str(int(time.time())),
|
| 137 |
+
"history": '[{"role": "user", "content": "%s"}]' % req.initial_prompt.replace('"', '\\"') # 转义双引号,避免JSON错误
|
| 138 |
}
|
| 139 |
|
| 140 |
+
# 保存任务到Redis(使用注入的客户端)
|
| 141 |
redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
|
| 142 |
redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
|
| 143 |
|
|
|
|
| 148 |
|
| 149 |
# 调用节点生成代码
|
| 150 |
try:
|
| 151 |
+
node["load"] += 1
|
| 152 |
async with httpx.AsyncClient(timeout=60) as client:
|
| 153 |
resp = await client.post(
|
| 154 |
f"{node['url']}/generate/code",
|
| 155 |
json={"prompt": req.initial_prompt, "max_new_tokens": 512},
|
| 156 |
params={"api_key": node["api_key"]}
|
| 157 |
)
|
| 158 |
+
node["load"] -= 1
|
| 159 |
resp.raise_for_status()
|
| 160 |
except Exception as e:
|
| 161 |
node["load"] -= 1
|
| 162 |
raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
|
| 163 |
|
| 164 |
+
# 更新任务历史(修复eval风险,用json.loads替代)
|
| 165 |
+
import json # 建议在文件顶部导入,此处为了突出修复点
|
| 166 |
result = resp.json()
|
| 167 |
+
history = json.loads(task_data["history"]) # 用json.loads替代eval,更安全
|
| 168 |
+
history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
|
| 169 |
redis_client.hset(
|
| 170 |
task_key(req.user_id, task_id),
|
| 171 |
+
mapping={
|
| 172 |
+
"history": json.dumps(history), # 用json.dumps存储,避免格式错误
|
| 173 |
+
"updated": str(int(time.time()))
|
| 174 |
+
}
|
| 175 |
)
|
| 176 |
|
| 177 |
return {
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
@app.post("/task/message")
|
| 184 |
+
async def send_message(
|
| 185 |
+
req: SendMessageRequest,
|
| 186 |
+
api_key: str = Depends(verify_api_key),
|
| 187 |
+
redis_client: redis.Redis = Depends(get_redis_client)
|
| 188 |
+
):
|
| 189 |
"""向现有任务发送消息"""
|
| 190 |
+
import json
|
| 191 |
key = task_key(req.user_id, req.task_id)
|
| 192 |
if not redis_client.exists(key):
|
| 193 |
raise HTTPException(status_code=404, detail="Task not found")
|
| 194 |
|
| 195 |
+
# 获取历史对话(用json.loads替代eval)
|
| 196 |
task_data = redis_client.hgetall(key)
|
| 197 |
+
history = json.loads(task_data["history"])
|
| 198 |
+
history.append({"role": "user", "content": req.message.replace('"', '\\"')})
|
| 199 |
|
| 200 |
# 构建上下文
|
| 201 |
context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
|
|
|
|
| 226 |
|
| 227 |
# 更新历史
|
| 228 |
result = resp.json()
|
| 229 |
+
history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
|
| 230 |
redis_client.hset(
|
| 231 |
key,
|
| 232 |
+
mapping={
|
| 233 |
+
"history": json.dumps(history),
|
| 234 |
+
"updated": str(int(time.time()))
|
| 235 |
+
}
|
| 236 |
)
|
| 237 |
|
| 238 |
return {"status": "success", "response": result["code"]}
|
| 239 |
|
| 240 |
@app.get("/user/tasks/{user_id}")
|
| 241 |
+
async def list_tasks(
|
| 242 |
+
user_id: str,
|
| 243 |
+
api_key: str = Depends(verify_api_key),
|
| 244 |
+
redis_client: redis.Redis = Depends(get_redis_client)
|
| 245 |
+
):
|
| 246 |
"""获取用户所有任务"""
|
| 247 |
+
import json
|
| 248 |
task_ids = redis_client.smembers(f"user:{user_id}:tasks")
|
| 249 |
tasks = []
|
| 250 |
for tid in task_ids:
|
|
|
|
| 255 |
"name": data["name"],
|
| 256 |
"created": time.ctime(int(data["created"])),
|
| 257 |
"updated": time.ctime(int(data["updated"])),
|
| 258 |
+
"history_length": len(json.loads(data["history"]))
|
| 259 |
})
|
| 260 |
return {"status": "success", "tasks": tasks}
|
| 261 |
|
|
|
|
| 264 |
"""查看节点健康状态"""
|
| 265 |
status = []
|
| 266 |
for node in PROCESSING_NODES:
|
| 267 |
+
# 主动触发健康检查(更新最新状态)
|
| 268 |
+
await check_node_health(node)
|
| 269 |
status.append({
|
| 270 |
"url": node["url"],
|
| 271 |
"health": node["health"],
|
| 272 |
"load": node["load"],
|
| 273 |
"last_check": time.ctime(node["last_check"]) if node["last_check"] else "Never"
|
| 274 |
})
|
| 275 |
+
return {"status": "success", "nodes": status}
|
|
|