fiewolf1000 commited on
Commit
060bedc
·
verified ·
1 Parent(s): 54b5e86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -36
app.py CHANGED
@@ -7,6 +7,7 @@ import redis
7
  import time
8
  import os
9
  from typing import Dict, List, Optional
 
10
 
11
  # ---------------------- 1. 初始化配置 ----------------------
12
  app = FastAPI(title="Distributed Agent Controller", version="1.0")
@@ -14,36 +15,48 @@ app = FastAPI(title="Distributed Agent Controller", version="1.0")
14
  # 跨域配置(适配Hugging Face Space环境)
15
  app.add_middleware(
16
  CORSMiddleware,
17
- allow_origins=["*"], # Space环境下允许跨域调用
18
  allow_credentials=True,
19
  allow_methods=["*"],
20
  allow_headers=["*"],
21
  )
22
 
23
- # Redis连接(必须使用公网Redis,Space内无本地Redis)
24
- def get_redis_client():
 
25
  try:
26
- return redis.Redis(
27
  host=os.getenv("REDIS_HOST"),
28
  port=int(os.getenv("REDIS_PORT", 6379)),
29
  password=os.getenv("REDIS_PASSWORD"),
30
  decode_responses=True,
31
  socket_timeout=5,
32
- retry_on_timeout=True
 
33
  )
 
 
 
 
34
  except Exception as e:
35
- raise RuntimeError(f"Redis connection failed: {str(e)}")
36
 
37
- redis_client = get_redis_client()
 
38
 
39
- # API密钥验证
 
 
 
 
 
40
  def verify_api_key(api_key: Optional[str] = None):
41
  valid_key = os.getenv("CONTROLLER_API_KEY", "default-key")
42
  if api_key != valid_key:
43
  raise HTTPException(status_code=401, detail="Invalid API key")
44
  return api_key
45
 
46
- # 处理节点配置(从环境变量读取,支持动态扩展)
47
  PROCESSING_NODES = [
48
  {
49
  "url": os.getenv("NODE_1_URL", "https://your-username-node1.hf.space"),
@@ -61,7 +74,7 @@ PROCESSING_NODES = [
61
  }
62
  ]
63
 
64
- # ---------------------- 2. 数据模型 ----------------------
65
  class CreateTaskRequest(BaseModel):
66
  user_id: str
67
  task_name: str
@@ -74,7 +87,7 @@ class SendMessageRequest(BaseModel):
74
  max_new_tokens: int = 512
75
  temperature: float = 0.7
76
 
77
- # ---------------------- 3. 工具函数 ----------------------
78
  async def check_node_health(node: Dict) -> bool:
79
  """检查节点健康状态,5分钟缓存结果"""
80
  current_time = time.time()
@@ -93,35 +106,38 @@ async def check_node_health(node: Dict) -> bool:
93
  return True
94
  node["health"] = "dead"
95
  return False
96
- except:
97
  node["health"] = "dead"
98
  return False
99
 
100
  def get_least_loaded_node() -> Optional[Dict]:
101
- """选择负载最低的健康节点(优化负载均衡)"""
102
  healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
103
  if not healthy_nodes:
104
  return None
105
- # 按当前负载排序,选择负载最低的
106
  return min(healthy_nodes, key=lambda x: x["load"])
107
 
108
  def task_key(user_id: str, task_id: str) -> str:
109
  return f"user:{user_id}:task:{task_id}"
110
 
111
- # ---------------------- 4. 核心接口 ----------------------
112
  @app.post("/task/create")
113
- async def create_task(req: CreateTaskRequest, api_key: str = Depends(verify_api_key)):
 
 
 
 
114
  """创建新任务并生成初始响应"""
115
  task_id = f"task_{uuid.uuid4().hex[:8]}"
116
  task_data = {
117
  "task_id": task_id,
118
  "name": req.task_name,
119
- "created": int(time.time()),
120
- "updated": int(time.time()),
121
- "history": [{"role": "user", "content": req.initial_prompt}]
122
  }
123
 
124
- # 保存任务到Redis
125
  redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
126
  redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
127
 
@@ -132,26 +148,30 @@ async def create_task(req: CreateTaskRequest, api_key: str = Depends(verify_api_
132
 
133
  # 调用节点生成代码
134
  try:
135
- node["load"] += 1 # 增加负载计数
136
  async with httpx.AsyncClient(timeout=60) as client:
137
  resp = await client.post(
138
  f"{node['url']}/generate/code",
139
  json={"prompt": req.initial_prompt, "max_new_tokens": 512},
140
  params={"api_key": node["api_key"]}
141
  )
142
- node["load"] -= 1 # 减少负载计数
143
  resp.raise_for_status()
144
  except Exception as e:
145
  node["load"] -= 1
146
  raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
147
 
148
- # 更新任务历史
 
149
  result = resp.json()
150
- history = eval(task_data["history"]) # 从字符串恢复列表
151
- history.append({"role": "assistant", "content": result["code"]})
152
  redis_client.hset(
153
  task_key(req.user_id, task_id),
154
- mapping={"history": str(history), "updated": int(time.time())}
 
 
 
155
  )
156
 
157
  return {
@@ -161,16 +181,21 @@ async def create_task(req: CreateTaskRequest, api_key: str = Depends(verify_api_
161
  }
162
 
163
  @app.post("/task/message")
164
- async def send_message(req: SendMessageRequest, api_key: str = Depends(verify_api_key)):
 
 
 
 
165
  """向现有任务发送消息"""
 
166
  key = task_key(req.user_id, req.task_id)
167
  if not redis_client.exists(key):
168
  raise HTTPException(status_code=404, detail="Task not found")
169
 
170
- # 获取历史对话
171
  task_data = redis_client.hgetall(key)
172
- history = eval(task_data["history"])
173
- history.append({"role": "user", "content": req.message})
174
 
175
  # 构建上下文
176
  context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
@@ -201,17 +226,25 @@ async def send_message(req: SendMessageRequest, api_key: str = Depends(verify_ap
201
 
202
  # 更新历史
203
  result = resp.json()
204
- history.append({"role": "assistant", "content": result["code"]})
205
  redis_client.hset(
206
  key,
207
- mapping={"history": str(history), "updated": int(time.time())}
 
 
 
208
  )
209
 
210
  return {"status": "success", "response": result["code"]}
211
 
212
  @app.get("/user/tasks/{user_id}")
213
- async def list_tasks(user_id: str, api_key: str = Depends(verify_api_key)):
 
 
 
 
214
  """获取用户所有任务"""
 
215
  task_ids = redis_client.smembers(f"user:{user_id}:tasks")
216
  tasks = []
217
  for tid in task_ids:
@@ -222,7 +255,7 @@ async def list_tasks(user_id: str, api_key: str = Depends(verify_api_key)):
222
  "name": data["name"],
223
  "created": time.ctime(int(data["created"])),
224
  "updated": time.ctime(int(data["updated"])),
225
- "history_length": len(eval(data["history"]))
226
  })
227
  return {"status": "success", "tasks": tasks}
228
 
@@ -231,11 +264,12 @@ async def node_health(api_key: str = Depends(verify_api_key)):
231
  """查看节点健康状态"""
232
  status = []
233
  for node in PROCESSING_NODES:
 
 
234
  status.append({
235
  "url": node["url"],
236
  "health": node["health"],
237
  "load": node["load"],
238
  "last_check": time.ctime(node["last_check"]) if node["last_check"] else "Never"
239
  })
240
- return {"status": "success", "nodes": status}
241
-
 
7
  import time
8
  import os
9
  from typing import Dict, List, Optional
10
+ from redis.connection import ConnectionPool # 引入连接池
11
 
12
  # ---------------------- 1. 初始化配置 ----------------------
13
  app = FastAPI(title="Distributed Agent Controller", version="1.0")
 
15
  # 跨域配置(适配Hugging Face Space环境)
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"], # 生产环境替换为具体域名
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
+ # ---------------------- 关键修复:Redis连接池 + 依赖注入 ----------------------
25
+ def create_redis_pool() -> ConnectionPool:
26
+ """创建Redis连接池(替代全局客户端,避免析构问题)"""
27
  try:
28
+ pool = ConnectionPool(
29
  host=os.getenv("REDIS_HOST"),
30
  port=int(os.getenv("REDIS_PORT", 6379)),
31
  password=os.getenv("REDIS_PASSWORD"),
32
  decode_responses=True,
33
  socket_timeout=5,
34
+ retry_on_timeout=True,
35
+ max_connections=10 # 限制连接数,避免资源泄漏
36
  )
37
+ # 验证连接池可用性
38
+ with redis.Redis(connection_pool=pool) as client:
39
+ client.ping()
40
+ return pool
41
  except Exception as e:
42
+ raise RuntimeError(f"Redis connection pool failed: {str(e)}")
43
 
44
+ # 初始化连接池(全局仅创建一次)
45
+ redis_pool = create_redis_pool()
46
 
47
+ def get_redis_client() -> redis.Redis:
48
+ """依赖注入:每次请求创建独立客户端,请求结束自动关闭"""
49
+ with redis.Redis(connection_pool=redis_pool) as client:
50
+ yield client # 用yield确保客户端使用后自动释放
51
+
52
+ # ---------------------- 2. API密钥验证 ----------------------
53
  def verify_api_key(api_key: Optional[str] = None):
54
  valid_key = os.getenv("CONTROLLER_API_KEY", "default-key")
55
  if api_key != valid_key:
56
  raise HTTPException(status_code=401, detail="Invalid API key")
57
  return api_key
58
 
59
+ # ---------------------- 3. 处理节点配置 ----------------------
60
  PROCESSING_NODES = [
61
  {
62
  "url": os.getenv("NODE_1_URL", "https://your-username-node1.hf.space"),
 
74
  }
75
  ]
76
 
77
+ # ---------------------- 4. 数据模型 ----------------------
78
  class CreateTaskRequest(BaseModel):
79
  user_id: str
80
  task_name: str
 
87
  max_new_tokens: int = 512
88
  temperature: float = 0.7
89
 
90
+ # ---------------------- 5. 工具函数 ----------------------
91
  async def check_node_health(node: Dict) -> bool:
92
  """检查节点健康状态,5分钟缓存结果"""
93
  current_time = time.time()
 
106
  return True
107
  node["health"] = "dead"
108
  return False
109
+ except Exception as e:
110
  node["health"] = "dead"
111
  return False
112
 
113
  def get_least_loaded_node() -> Optional[Dict]:
114
+ """选择负载最低的健康节点"""
115
  healthy_nodes = [n for n in PROCESSING_NODES if check_node_health(n)]
116
  if not healthy_nodes:
117
  return None
 
118
  return min(healthy_nodes, key=lambda x: x["load"])
119
 
120
  def task_key(user_id: str, task_id: str) -> str:
121
  return f"user:{user_id}:task:{task_id}"
122
 
123
+ # ---------------------- 6. 核心接口(修复Redis调用逻辑) ----------------------
124
  @app.post("/task/create")
125
+ async def create_task(
126
+ req: CreateTaskRequest,
127
+ api_key: str = Depends(verify_api_key),
128
+ redis_client: redis.Redis = Depends(get_redis_client) # 依赖注入Redis客户端
129
+ ):
130
  """创建新任务并生成初始响应"""
131
  task_id = f"task_{uuid.uuid4().hex[:8]}"
132
  task_data = {
133
  "task_id": task_id,
134
  "name": req.task_name,
135
+ "created": str(int(time.time())), # 存储为字符串,避免Redis哈希类型转换问题
136
+ "updated": str(int(time.time())),
137
+ "history": '[{"role": "user", "content": "%s"}]' % req.initial_prompt.replace('"', '\\"') # 转义双引号,避免JSON错误
138
  }
139
 
140
+ # 保存任务到Redis(使用注入的客户端)
141
  redis_client.hset(task_key(req.user_id, task_id), mapping=task_data)
142
  redis_client.sadd(f"user:{req.user_id}:tasks", task_id)
143
 
 
148
 
149
  # 调用节点生成代码
150
  try:
151
+ node["load"] += 1
152
  async with httpx.AsyncClient(timeout=60) as client:
153
  resp = await client.post(
154
  f"{node['url']}/generate/code",
155
  json={"prompt": req.initial_prompt, "max_new_tokens": 512},
156
  params={"api_key": node["api_key"]}
157
  )
158
+ node["load"] -= 1
159
  resp.raise_for_status()
160
  except Exception as e:
161
  node["load"] -= 1
162
  raise HTTPException(status_code=500, detail=f"Node error: {str(e)}")
163
 
164
+ # 更新任务历史(修复eval风险,用json.loads替代)
165
+ import json # 建议在文件顶部导入,此处为了突出修复点
166
  result = resp.json()
167
+ history = json.loads(task_data["history"]) # 用json.loads替代eval,更安全
168
+ history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
169
  redis_client.hset(
170
  task_key(req.user_id, task_id),
171
+ mapping={
172
+ "history": json.dumps(history), # 用json.dumps存储,避免格式错误
173
+ "updated": str(int(time.time()))
174
+ }
175
  )
176
 
177
  return {
 
181
  }
182
 
183
  @app.post("/task/message")
184
+ async def send_message(
185
+ req: SendMessageRequest,
186
+ api_key: str = Depends(verify_api_key),
187
+ redis_client: redis.Redis = Depends(get_redis_client)
188
+ ):
189
  """向现有任务发送消息"""
190
+ import json
191
  key = task_key(req.user_id, req.task_id)
192
  if not redis_client.exists(key):
193
  raise HTTPException(status_code=404, detail="Task not found")
194
 
195
+ # 获取历史对话(用json.loads替代eval)
196
  task_data = redis_client.hgetall(key)
197
+ history = json.loads(task_data["history"])
198
+ history.append({"role": "user", "content": req.message.replace('"', '\\"')})
199
 
200
  # 构建上下文
201
  context = "\n".join([f"{m['role']}: {m['content']}" for m in history])
 
226
 
227
  # 更新历史
228
  result = resp.json()
229
+ history.append({"role": "assistant", "content": result["code"].replace('"', '\\"')})
230
  redis_client.hset(
231
  key,
232
+ mapping={
233
+ "history": json.dumps(history),
234
+ "updated": str(int(time.time()))
235
+ }
236
  )
237
 
238
  return {"status": "success", "response": result["code"]}
239
 
240
  @app.get("/user/tasks/{user_id}")
241
+ async def list_tasks(
242
+ user_id: str,
243
+ api_key: str = Depends(verify_api_key),
244
+ redis_client: redis.Redis = Depends(get_redis_client)
245
+ ):
246
  """获取用户所有任务"""
247
+ import json
248
  task_ids = redis_client.smembers(f"user:{user_id}:tasks")
249
  tasks = []
250
  for tid in task_ids:
 
255
  "name": data["name"],
256
  "created": time.ctime(int(data["created"])),
257
  "updated": time.ctime(int(data["updated"])),
258
+ "history_length": len(json.loads(data["history"]))
259
  })
260
  return {"status": "success", "tasks": tasks}
261
 
 
264
  """查看节点健康状态"""
265
  status = []
266
  for node in PROCESSING_NODES:
267
+ # 主动触发健康检查(更新最新状态)
268
+ await check_node_health(node)
269
  status.append({
270
  "url": node["url"],
271
  "health": node["health"],
272
  "load": node["load"],
273
  "last_check": time.ctime(node["last_check"]) if node["last_check"] else "Never"
274
  })
275
+ return {"status": "success", "nodes": status}