Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| 虫群节点服务 — 产品版v1.0 | |
| FastAPI后端,支持对话/训练/联邦/权重同步 | |
| """ | |
| import os | |
| import json | |
| import time | |
| import base64 | |
| import numpy as np | |
| from typing import Optional, Dict, List | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from pydantic import BaseModel | |
# Path setup: put the project root, src/ and its sub-packages on sys.path so
# the lazy imports below (core.brain, chat.chat_engine, ...) resolve.
# Use the real `sys` module rather than the undocumented `os.sys` alias.
import sys

# Three dirname() calls: this file is assumed to sit two directory levels
# below the project root.
_BASE = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
_SRC = os.path.join(_BASE, 'src')
# insert(0, ...) in list order means the last entry ends up first on sys.path.
for p in [_BASE, _SRC, os.path.join(_SRC, 'core'),
          os.path.join(_SRC, 'chat'), os.path.join(_SRC, 'node'),
          os.path.join(_SRC, 'bridge')]:
    if p not in sys.path:
        sys.path.insert(0, p)
app = FastAPI(version="v1.0", title="虫群产品版")

# Process-wide singletons; each stays None until its getter first builds it.
# NOTE(review): _fed is not referenced anywhere in the visible code.
_brain = _chat = _fed = None
def get_brain():
    """Return the shared Brain instance, constructing it on the first call.

    The import of ``core.brain`` is deferred until construction so that the
    module can be loaded without pulling in the model code.
    """
    global _brain
    if _brain is None:
        from core.brain import Brain
        _brain = Brain()
        print('[Node] Brain已初始化')
    return _brain
def get_chat():
    """Return the shared ChatEngine instance, constructing it on the first call.

    API key resolution:
      1. the ``NIM_API_KEY`` environment variable;
      2. if that is empty or exactly ``'***'``, the contents of
         ``~/.swarm/.nim_key`` (surrounding whitespace stripped), when the
         file exists.
    If neither supplies a key, the env value ('' or '***') is passed through
    unchanged. Constructing the engine also initializes the Brain via
    ``get_brain()``.
    """
    global _chat
    if _chat is not None:
        return _chat
    from chat.chat_engine import ChatEngine
    api_key = os.environ.get('NIM_API_KEY', '')
    # '***' is presumably a masked placeholder meaning "no real key set";
    # fall back to the key file in that case too.
    if not api_key or api_key == '***':
        key_file = os.path.expanduser('~/.swarm/.nim_key')
        if os.path.exists(key_file):
            # Explicit encoding: the platform default is locale-dependent.
            with open(key_file, encoding='utf-8') as f:
                api_key = f.read().strip()
    _chat = ChatEngine(brain=get_brain(), api_key=api_key)
    print('[Node] ChatEngine已初始化')
    return _chat
# ========== Request/response models ==========
class ChatRequest(BaseModel):
    # Request body for the `chat` handler.
    text: str  # the user message; for teach requests, the text being taught
    teach: Optional[str] = None  # when non-empty, the handler calls ChatEngine.teach(text, teach) instead of chatting
class ChatResponse(BaseModel):
    # Reply from the `chat` handler. For normal chats the fields are copied
    # from the dict returned by ChatEngine.chat(), with defaults for missing keys.
    text: str
    mode: str  # 'teach' for teach requests; otherwise the engine's value ('unknown' if absent)
    ms: int  # elapsed milliseconds as reported by the engine (0 for teach requests)
    confidence: float = 0.0
    decoded_words: list = []
class TrainRequest(BaseModel):
    # Request body for the `train` handler; forwarded to Brain.train().
    texts: List[str] = []  # training samples
    epochs: int = 1
    learning_rate: float = 0.01  # passed to Brain.train as `lr`
class TrainResponse(BaseModel):
    # Result of the `train` handler.
    status: str  # 'ok' on success (failures are raised as HTTP 500 instead)
    samples: int = 0  # number of texts submitted
    epochs: int = 0
    ms: int = 0  # wall-clock training time in milliseconds
class WeightsResponse(BaseModel):
    # Serialized weights of one brain area. Returned by `get_weights` and also
    # accepted as the request body of `set_weights`.
    area: str
    version: int  # NOTE(review): the handlers in this file always emit 0
    weights_b64: str  # base64 of compressed .npz bytes holding the array under key 'w'; on input, raw float32 bytes are also accepted
    shape: List[int]  # array shape; on input, only used to reshape the raw-float32 fallback
class FedAvgRequest(BaseModel):
    # Request body for the `fedavg` handler: one peer's weights for one area.
    area: str
    weights_b64: str  # base64 of .npz bytes (key 'w') or of raw float32 bytes
    shape: List[int]  # used to reshape the raw-float32 fallback
    version: int  # NOTE(review): not read by the fedavg handler in this file
    node_id: str = 'unknown'  # forwarded to Brain.fedavg to identify the sender
class SyncRequest(BaseModel):
    # Request body for the `sync` handler.
    areas: List[str] = []  # areas to export; empty means every area in brain.areas
# ========== Page routes ==========
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve a static HTML landing page that lists the node's API endpoints.

    Registered with ``response_class=HTMLResponse`` so the string is sent as
    ``text/html``; with the default JSON response it would arrive as a quoted,
    escaped JSON string.
    """
    return """<html><head><title>虫群产品版 v1.0</title></head>
<body style="font-family:sans-serif;max-width:800px;margin:40px auto;padding:0 20px">
<h1>🐛 虫群产品版 v1.0</h1>
<p>类脑Meta Model + 四层路由 + 联邦训练</p>
<h2>API端点</h2>
<ul>
<li><b>POST /chat</b> — 对话</li>
<li><b>POST /train</b> — 训练</li>
<li><b>GET /weights/{area}</b> — 获取权重</li>
<li><b>POST /weights/{area}</b> — 设置权重</li>
<li><b>POST /fedavg</b> — 联邦聚合</li>
<li><b>POST /sync</b> — 同步权重</li>
<li><b>GET /health</b> — 健康检查</li>
<li><b>GET /stats</b> — 统计信息</li>
</ul></body></html>"""
# ========== API routes ==========
@app.get("/health")
async def health():
    """Health check reporting whether the lazy singletons have been built yet.

    Does not trigger initialization. 'loaded' means the Brain/ChatEngine
    already exists; 'lazy' means it will be created on first use.
    """
    # `is not None` rather than truthiness: an object defining __len__/__bool__
    # could be falsy even though it has been constructed.
    brain_status = 'loaded' if _brain is not None else 'lazy'
    chat_status = 'loaded' if _chat is not None else 'lazy'
    return {
        "status": "ok", "version": "v1.0",
        "brain": brain_status, "chat": chat_status,
    }
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Chat with the node, or teach it a response when ``req.teach`` is set.

    A non-empty ``req.teach`` calls ``ChatEngine.teach(req.text, req.teach)``
    and returns a fixed acknowledgement with mode 'teach'. Otherwise
    ``ChatEngine.chat(req.text)`` is called, and its result dict is mapped
    onto ChatResponse with defaults for any missing keys.
    """
    chat_engine = get_chat()
    if req.teach:
        chat_engine.teach(req.text, req.teach)
        return ChatResponse(text=f"已学习: {req.text}", mode="teach", ms=0)
    # NOTE(review): engine calls run synchronously inside an async handler and
    # block the event loop for their duration.
    result = chat_engine.chat(req.text)
    return ChatResponse(
        text=result.get('text', ''),
        mode=result.get('mode', 'unknown'),
        ms=result.get('ms', 0),
        confidence=result.get('confidence', 0),
        decoded_words=result.get('decoded_words', []),
    )
@app.post("/train", response_model=TrainResponse)
async def train(req: TrainRequest):
    """Train the Brain on ``req.texts`` and report the sample count and duration.

    Raises:
        HTTPException: status 500 if ``Brain.train`` raises; the original
            error message becomes the detail and the exception is chained.
    """
    brain = get_brain()
    # perf_counter is monotonic, so the measured duration cannot jump or go
    # negative if the wall clock is adjusted (time.time() can).
    t0 = time.perf_counter()
    try:
        # NOTE(review): training runs synchronously and blocks the event loop
        # (including /health) until it finishes.
        brain.train(req.texts, epochs=req.epochs, lr=req.learning_rate)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    ms = int((time.perf_counter() - t0) * 1000)
    return TrainResponse(
        status='ok', samples=len(req.texts),
        epochs=req.epochs, ms=ms,
    )
@app.get("/weights/{area}", response_model=WeightsResponse)
async def get_weights(area: str):
    """Export one brain area's weights as base64-encoded compressed .npz.

    An area with no weights (size 0) yields empty ``weights_b64`` and
    ``shape``. ``version`` is always 0.

    Raises:
        HTTPException: 404 if ``area`` is not in ``brain.areas``; 500 if
            fetching or serializing the weights fails.
    """
    brain = get_brain()
    if area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{area}不存在")
    try:
        weights = brain.get_area_weights(area)
        if weights.size == 0:
            # Same payload as before, but typed like the non-empty path so
            # both branches return a WeightsResponse.
            return WeightsResponse(area=area, version=0, weights_b64="", shape=[])
        b64 = base64.b64encode(io_loop(weights)).decode()
        return WeightsResponse(
            area=area, version=0,
            weights_b64=b64, shape=list(weights.shape),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"权重导出失败: {e}") from e
@app.post("/weights/{area}")
async def set_weights(area: str, req: WeightsResponse):
    """Replace one brain area's weights with an uploaded payload.

    ``req.weights_b64`` is base64-decoded and first parsed as an .npz archive
    holding the array under key 'w'. If that fails, the bytes are reinterpreted
    as raw float32 values reshaped to ``req.shape``. The ``area`` path
    parameter selects the target; ``req.area`` is not used.

    Raises:
        HTTPException: 404 if ``area`` is not in ``brain.areas``; 500 if
            decoding or ``Brain.set_area_weights`` fails.
    """
    import io
    brain = get_brain()
    if area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{area}不存在")
    try:
        raw = base64.b64decode(req.weights_b64)
        try:
            # SECURITY: allow_pickle=False. The payload comes from the
            # network, and unpickling it would let any client run arbitrary
            # code on this node. `with` closes the NpzFile; npz['w'] reads the
            # array into memory, so it stays valid after the file is closed.
            with np.load(io.BytesIO(raw), allow_pickle=False) as npz:
                weights = npz['w']
        except Exception:
            # Not a (pickle-free) .npz archive: treat it as raw float32 bytes.
            weights = np.frombuffer(raw, dtype=np.float32).reshape(req.shape)
        brain.set_area_weights(area, weights)
        return {"status": "ok", "area": area}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"权重导入失败: {e}") from e
@app.post("/fedavg")
async def fedavg(req: FedAvgRequest):
    """Merge a peer node's weights for one area into the local Brain.

    ``req.weights_b64`` is base64-decoded and first parsed as an .npz archive
    holding the array under key 'w'. If that fails, the bytes are reinterpreted
    as raw float32 values reshaped to ``req.shape``. The decoded array is
    passed to ``Brain.fedavg(area, incoming, node_id)``, and its return value
    is echoed back under "result".

    Raises:
        HTTPException: 404 if ``req.area`` is not in ``brain.areas``; 500 if
            decoding or aggregation fails.
    """
    import io
    brain = get_brain()
    if req.area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{req.area}不存在")
    try:
        raw = base64.b64decode(req.weights_b64)
        # Try the (compressed) npz format first.
        try:
            # SECURITY: allow_pickle=False. Peer payloads are untrusted, and
            # unpickling them would allow arbitrary code execution. `with`
            # closes the NpzFile after the 'w' array has been read into memory.
            with np.load(io.BytesIO(raw), allow_pickle=False) as npz:
                incoming = npz['w']
        except Exception:
            # Not a (pickle-free) .npz archive: treat it as raw float32 bytes.
            incoming = np.frombuffer(raw, dtype=np.float32).reshape(req.shape)
        result = brain.fedavg(req.area, incoming, req.node_id)
        return {"status": "ok", "area": req.area, "result": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"联邦聚合失败: {e}") from e
@app.post("/sync")
async def sync(req: SyncRequest):
    """Export several brain areas in one response, keyed by area name.

    Each entry holds ``weights_b64`` (base64 of compressed .npz), ``shape`` and
    ``version`` (always 0). An empty ``req.areas`` exports every area.

    Three kinds of area are skipped:
      - names not in ``brain.areas``;
      - areas with no weights;
      - areas whose export raises. These failures are logged rather than
        silently dropped, and one bad area does not fail the whole sync.
    """
    brain = get_brain()
    areas = req.areas or list(brain.areas.keys())
    result = {}
    for area in areas:
        if area not in brain.areas:
            continue
        try:
            w = brain.get_area_weights(area)
            if w.size == 0:
                continue
            buf = io_loop(w)
        except Exception as e:
            # Best effort per area, but leave a trace so the failure can be diagnosed.
            print(f'[Node] sync skipped area {area}: {e!r}')
            continue
        result[area] = {
            'weights_b64': base64.b64encode(buf).decode(),
            'shape': list(w.shape),
            'version': 0,
        }
    return result
@app.get("/stats")
async def stats():
    """Collect statistics from the Brain and the ChatEngine.

    Unlike /health, this endpoint triggers lazy initialization of both
    singletons. A failure in either part is reported under a ``*_error`` key
    instead of failing the whole request.
    """
    info = {'version': 'v1.0'}
    try:
        brain = get_brain()
        info['brain'] = brain.stats()
    except Exception as e:
        info['brain_error'] = str(e)
    try:
        # Named chat_engine so it does not shadow the module-level `chat` route.
        chat_engine = get_chat()
        info['chat'] = chat_engine.stats()
    except Exception as e:
        info['chat_error'] = str(e)
    return info
def io_loop(arr: np.ndarray) -> bytes:
    """Serialize *arr* to compressed .npz bytes, storing it under the key 'w'.

    Produces the payload format that the weight import paths try first.
    """
    import io
    with io.BytesIO() as stream:
        np.savez_compressed(stream, w=arr)
        # Grab the bytes before the buffer is closed on leaving the block.
        return stream.getvalue()
# ========== Startup ==========
if __name__ == '__main__':
    # Run the app directly with uvicorn. The listen port comes from SWARM_PORT
    # (default 7860) on all interfaces.
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=int(os.environ.get('SWARM_PORT', '7860')))