fiewolf1000 commited on
Commit
1f8792f
·
verified ·
1 Parent(s): 24ae713

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -38
app.py CHANGED
@@ -7,45 +7,58 @@ import redis
7
  import time
8
  import os
9
  from typing import Dict, List, Optional
10
- from redis.connection import ConnectionPool # 引入连接
11
 
12
  # ---------------------- 1. 初始化配置 ----------------------
13
  app = FastAPI(title="Distributed Agent Controller", version="1.0")
14
 
15
- # 跨域配置(适配Hugging Face Space环境)
16
  app.add_middleware(
17
  CORSMiddleware,
18
- allow_origins=["*"], # 生产环境替换为具体域名
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
- # ---------------------- 关键修复Redis连接池 + 依赖注入 ----------------------
25
  def create_redis_pool() -> ConnectionPool:
26
  try:
 
 
 
 
 
 
 
 
 
27
  pool = ConnectionPool(
28
- host=os.getenv("REDIS_HOST"), # 填 Upstash 的 Host(xxx.upstash.io)
29
- port=int(os.getenv("REDIS_PORT", 6380)), # 固定填 6380(加密端口)
30
- password=os.getenv("REDIS_PASSWORD"), # 填 Upstash 的实例密码
31
  decode_responses=True,
32
- ssl=True, # 必须开启 SSL,Upstash 强制加密连接
33
  socket_timeout=5,
34
- retry_on_timeout=True
 
35
  )
 
 
36
  with redis.Redis(connection_pool=pool) as client:
37
- client.ping() # 验证连接
 
38
  return pool
39
  except Exception as e:
40
- raise RuntimeError(f"Redis connection failed: {str(e)}")
41
 
42
- # 初始化连接池(全局仅创建一次)
43
  redis_pool = create_redis_pool()
44
 
 
45
  def get_redis_client() -> redis.Redis:
46
- """依赖注入:每次请求创建独立客户端,请求结束自动关闭"""
47
  with redis.Redis(connection_pool=redis_pool) as client:
48
- yield client # 用yield确保客户端使用后自动释放
49
 
50
  # ---------------------- 2. API密钥验证 ----------------------
51
  def verify_api_key(api_key: Optional[str] = None):
@@ -61,7 +74,7 @@ PROCESSING_NODES = [
61
  "api_key": os.getenv("NODE_API_KEY", "node-key"),
62
  "health": "unknown",
63
  "last_check": 0,
64
- "load": 0 # 节点负载计数
65
  },
66
  {
67
  "url": os.getenv("NODE_2_URL", "https://your-username-node2.hf.space"),
@@ -87,7 +100,6 @@ class SendMessageRequest(BaseModel):
87
 
88
  # ---------------------- 5. 工具函数 ----------------------
89
  async def check_node_health(node: Dict) -> bool:
90
- """检查节点健康状态,5分钟缓存结果"""
91
  current_time = time.time()
92
  if current_time - node["last_check"] < 300:
93
  return node["health"] == "alive"
@@ -109,7 +121,6 @@ async def check_node_health(node: Dict) -> bool:
109
  return False
110
 
111
  def get_least_loaded_node() -> Optional[Dict]:
112
- """选择负载最低的健康节点"""
113
  healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
114
  if not healthy_nodes:
115
  return None
@@ -118,33 +129,29 @@ def get_least_loaded_node() -> Optional[Dict]:
118
  def task_key(user_id: str, task_id: str) -> str:
119
  return f"user:{user_id}:task:{task_id}"
120
 
121
- # ---------------------- 6. 核心接口(修复Redis调用逻辑) ----------------------
122
  @app.post("/task/create")
123
  async def create_task(
124
  req: CreateTaskRequest,
125
  api_key: str = Depends(verify_api_key),
126
- redis_client: redis.Redis = Depends(get_redis_client) # 依赖注入Redis客户端
127
  ):
128
- """创建新任务并生成初始响应"""
129
  task_id = f"task_{uuid.uuid4().hex[:8]}"
130
  task_data = {
131
  "task_id": task_id,
132
  "name": req.task_name,
133
- "created": str(int(time.time())), # 存储为字符串,避免Redis哈希类型转换问题
134
  "updated": str(int(time.time())),
135
- "history": '[{"role": "user", "content": "%s"}]' % req.initial_prompt.replace('"', '\\"') # 转义��引号,避免JSON错误
136
  }
137
 
138
- # 保存任务到Redis(使用注入的客户端)
139
  redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
140
  redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
141
 
142
- # 分配节点处理
143
  node = get_least_loaded_node()
144
  if not node:
145
  raise HTTPException(status_code=503, detail="No available nodes")
146
 
147
- # 调用节点生成代码
148
  try:
149
  node["load"] += 1
150
  async with httpx.AsyncClient(timeout=60) as client:
@@ -159,15 +166,14 @@ async def create_task(
159
  node["load"] -= 1
160
  raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
161
 
162
- # 更新任务历史(修复eval风险,用json.loads替代)
163
- import json # 建议在文件顶部导入,此处为了突出修复点
164
  result = resp.json()
165
- history = json.loads(task_data["history"]) # 用json.loads替代eval,更安全
166
  history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
167
  redis_client.hset(
168
  task_key(req.user_id, task_id),
169
  mapping={
170
- "history": json.dumps(history), # 用json.dumps存储,避免格式错误
171
  "updated": str(int(time.time()))
172
  }
173
  )
@@ -184,26 +190,21 @@ async def send_message(
184
  api_key: str = Depends(verify_api_key),
185
  redis_client: redis.Redis = Depends(get_redis_client)
186
  ):
187
- """向现有任务发送消息"""
188
  import json
189
  key = task_key(req.user_id, req.task_id)
190
  if not redis_client.exists(key):
191
  raise HTTPException(status_code=404, detail="Task not found")
192
 
193
- # 获取历史对话(用json.loads替代eval)
194
  task_data = redis_client.hgetall(key)
195
  history = json.loads(task_data["history"])
196
  history.append({"role": "user", "content": req.message.replace('"', '\\"')})
197
 
198
- # 构建上下文
199
  context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
200
 
201
- # 分配节点处理
202
  node = get_least_loaded_node()
203
  if not node:
204
  raise HTTPException(status_code=503, detail="No available nodes")
205
 
206
- # 调用节点生成响应
207
  try:
208
  node["load"] += 1
209
  async with httpx.AsyncClient(timeout=60) as client:
@@ -222,7 +223,6 @@ async def send_message(
222
  node["load"] -= 1
223
  raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
224
 
225
- # 更新历史
226
  result = resp.json()
227
  history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
228
  redis_client.hset(
@@ -241,7 +241,6 @@ async def list_tasks(
241
  api_key: str = Depends(verify_api_key),
242
  redis_client: redis.Redis = Depends(get_redis_client)
243
  ):
244
- """获取用户所有任务"""
245
  import json
246
  task_ids = redis_client.smembers(f"user:{user_id}:tasks")
247
  tasks = []
@@ -259,10 +258,8 @@ async def list_tasks(
259
 
260
  @app.get("/nodes/health")
261
  async def node_health(api_key: str = Depends(verify_api_key)):
262
- """查看节点健康状态"""
263
  status = []
264
  for node in PROCESSING_NODES:
265
- # 主动触发健康检查(更新最新状态)
266
  await check_node_health(node)
267
  status.append({
268
  "url": node["url"],
 
7
  import time
8
  import os
9
  from typing import Dict, List, Optional
10
+ from redis.connection import ConnectionPool, SSLConnection # 引入SSL连接
11
 
12
  # ---------------------- 1. 初始化配置 ----------------------
13
  app = FastAPI(title="Distributed Agent Controller", version="1.0")
14
 
15
+ # 跨域配置
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"],
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
+ # ---------------------- 修复Redis连接池(兼容低版本redis-py) ----------------------
25
  def create_redis_pool() -> ConnectionPool:
26
  try:
27
+ # 从环境变量获取配置(必须在Space中正确设置)
28
+ redis_host = os.getenv("REDIS_HOST")
29
+ redis_port = int(os.getenv("REDIS_PORT", 6380))
30
+ redis_password = os.getenv("REDIS_PASSWORD")
31
+
32
+ if not all([redis_host, redis_password]):
33
+ raise ValueError("REDIS_HOST和REDIS_PASSWORD必须配置")
34
+
35
+ # 关键修复:使用SSLConnection类,替代ssl=True参数(兼容低版本redis-py)
36
  pool = ConnectionPool(
37
+ host=redis_host,
38
+ port=redis_port,
39
+ password=redis_password,
40
  decode_responses=True,
41
+ connection_class=SSLConnection, # 强制使用SSL加密连接
42
  socket_timeout=5,
43
+ retry_on_timeout=True,
44
+ max_connections=10
45
  )
46
+
47
+ # 验证连接
48
  with redis.Redis(connection_pool=pool) as client:
49
+ client.ping() # 成功会返回True
50
+ print("Redis连接池初始化成功")
51
  return pool
52
  except Exception as e:
53
+ raise RuntimeError(f"Redis连接池初始化失败:{str(e)}")
54
 
55
+ # 初始化连接池
56
  redis_pool = create_redis_pool()
57
 
58
+ # 依赖注入:获取Redis客户端
59
  def get_redis_client() -> redis.Redis:
 
60
  with redis.Redis(connection_pool=redis_pool) as client:
61
+ yield client
62
 
63
  # ---------------------- 2. API密钥验证 ----------------------
64
  def verify_api_key(api_key: Optional[str] = None):
 
74
  "api_key": os.getenv("NODE_API_KEY", "node-key"),
75
  "health": "unknown",
76
  "last_check": 0,
77
+ "load": 0
78
  },
79
  {
80
  "url": os.getenv("NODE_2_URL", "https://your-username-node2.hf.space"),
 
100
 
101
  # ---------------------- 5. 工具函数 ----------------------
102
  async def check_node_health(node: Dict) -> bool:
 
103
  current_time = time.time()
104
  if current_time - node["last_check"] < 300:
105
  return node["health"] == "alive"
 
121
  return False
122
 
123
  def get_least_loaded_node() -> Optional[Dict]:
 
124
  healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
125
  if not healthy_nodes:
126
  return None
 
129
  def task_key(user_id: str, task_id: str) -> str:
130
  return f"user:{user_id}:task:{task_id}"
131
 
132
+ # ---------------------- 6. 核心接口 ----------------------
133
  @app.post("/task/create")
134
  async def create_task(
135
  req: CreateTaskRequest,
136
  api_key: str = Depends(verify_api_key),
137
+ redis_client: redis.Redis = Depends(get_redis_client)
138
  ):
 
139
  task_id = f"task_{uuid.uuid4().hex[:8]}"
140
  task_data = {
141
  "task_id": task_id,
142
  "name": req.task_name,
143
+ "created": str(int(time.time())),
144
  "updated": str(int(time.time())),
145
+ "history": '[{"role": "user", "content": "%s"}]' % req.initial_prompt.replace('"', '\\"')
146
  }
147
 
 
148
  redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
149
  redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
150
 
 
151
  node = get_least_loaded_node()
152
  if not node:
153
  raise HTTPException(status_code=503, detail="No available nodes")
154
 
 
155
  try:
156
  node["load"] += 1
157
  async with httpx.AsyncClient(timeout=60) as client:
 
166
  node["load"] -= 1
167
  raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
168
 
169
+ import json
 
170
  result = resp.json()
171
+ history = json.loads(task_data["history"])
172
  history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
173
  redis_client.hset(
174
  task_key(req.user_id, task_id),
175
  mapping={
176
+ "history": json.dumps(history),
177
  "updated": str(int(time.time()))
178
  }
179
  )
 
190
  api_key: str = Depends(verify_api_key),
191
  redis_client: redis.Redis = Depends(get_redis_client)
192
  ):
 
193
  import json
194
  key = task_key(req.user_id, req.task_id)
195
  if not redis_client.exists(key):
196
  raise HTTPException(status_code=404, detail="Task not found")
197
 
 
198
  task_data = redis_client.hgetall(key)
199
  history = json.loads(task_data["history"])
200
  history.append({"role": "user", "content": req.message.replace('"', '\\"')})
201
 
 
202
  context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
203
 
 
204
  node = get_least_loaded_node()
205
  if not node:
206
  raise HTTPException(status_code=503, detail="No available nodes")
207
 
 
208
  try:
209
  node["load"] += 1
210
  async with httpx.AsyncClient(timeout=60) as client:
 
223
  node["load"] -= 1
224
  raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
225
 
 
226
  result = resp.json()
227
  history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
228
  redis_client.hset(
 
241
  api_key: str = Depends(verify_api_key),
242
  redis_client: redis.Redis = Depends(get_redis_client)
243
  ):
 
244
  import json
245
  task_ids = redis_client.smembers(f"user:{user_id}:tasks")
246
  tasks = []
 
258
 
259
  @app.get("/nodes/health")
260
  async def node_health(api_key: str = Depends(verify_api_key)):
 
261
  status = []
262
  for node in PROCESSING_NODES:
 
263
  await check_node_health(node)
264
  status.append({
265
  "url": node["url"],