#!/usr/bin/env python3
"""
Swarm node service — product edition v1.0.

FastAPI backend exposing chat, training, federated-averaging and
weight-synchronisation endpoints on top of a lazily created ``Brain`` and
``ChatEngine``.
"""

import base64
import io
import json
import os
import sys
import time
from typing import Dict, List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel

# Path setup: make the repository root and the src/ sub-packages importable
# so that ``core.brain`` / ``chat.chat_engine`` resolve when run as a script.
_BASE = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
_SRC = os.path.join(_BASE, 'src')
for p in [
    _BASE,
    _SRC,
    os.path.join(_SRC, 'core'),
    os.path.join(_SRC, 'chat'),
    os.path.join(_SRC, 'node'),
    os.path.join(_SRC, 'bridge'),
]:
    if p not in sys.path:
        sys.path.insert(0, p)

app = FastAPI(title="虫群产品版", version="v1.0")

# Global singletons, created lazily on first use.
_brain = None
_chat = None
_fed = None  # not referenced anywhere in this file; kept for compatibility


def get_brain():
    """Return the process-wide ``Brain``, creating it on first call."""
    global _brain
    if _brain is not None:
        return _brain
    # Imported here so the module can be loaded without building the model.
    from core.brain import Brain
    _brain = Brain()
    print('[Node] Brain已初始化')
    return _brain


def get_chat():
    """Return the process-wide ``ChatEngine``, creating it on first call.

    The API key is taken from the ``NIM_API_KEY`` environment variable. If
    that is empty or the placeholder ``'***'``, ``~/.swarm/.nim_key`` is read
    instead when it exists.
    """
    global _chat
    if _chat is not None:
        return _chat
    from chat.chat_engine import ChatEngine
    api_key = os.environ.get('NIM_API_KEY', '')
    # Fall back to the key file when the environment gives nothing usable.
    if not api_key or api_key == '***':
        key_file = os.path.expanduser('~/.swarm/.nim_key')
        if os.path.exists(key_file):
            with open(key_file, encoding='utf-8') as f:
                api_key = f.read().strip()
    _chat = ChatEngine(brain=get_brain(), api_key=api_key)
    print('[Node] ChatEngine已初始化')
    return _chat


# ========== Request / response models ==========

class ChatRequest(BaseModel):
    """Body of ``POST /chat``; a non-empty ``teach`` switches to teach mode."""
    text: str
    teach: Optional[str] = None


class ChatResponse(BaseModel):
    """Reply of ``POST /chat``."""
    text: str
    mode: str
    ms: int
    confidence: float = 0.0
    decoded_words: list = []


class TrainRequest(BaseModel):
    """Body of ``POST /train``."""
    texts: List[str] = []
    epochs: int = 1
    learning_rate: float = 0.01


class TrainResponse(BaseModel):
    """Reply of ``POST /train``."""
    status: str
    samples: int = 0
    epochs: int = 0
    ms: int = 0


class WeightsResponse(BaseModel):
    """Serialized weights of one area (also used as the import body)."""
    area: str
    version: int
    weights_b64: str
    shape: List[int]


class FedAvgRequest(BaseModel):
    """Body of ``POST /fedavg``."""
    area: str
    weights_b64: str
    shape: List[int]
    version: int
    node_id: str = 'unknown'


class SyncRequest(BaseModel):
    """Body of ``POST /sync``; an empty ``areas`` list means every area."""
    areas: List[str] = []


# ========== Weight (de)serialization helpers ==========

def io_loop(arr: np.ndarray) -> bytes:
    """Serialize *arr* to compressed ``.npz`` bytes under the key ``'w'``."""
    buf = io.BytesIO()
    np.savez_compressed(buf, w=arr)
    return buf.getvalue()


def _decode_weights(weights_b64: str, shape: List[int]) -> np.ndarray:
    """Decode a base64 weight payload into a numpy array.

    The payload is first read as an ``.npz`` archive holding the array under
    ``'w'`` (the format ``io_loop`` produces). If that fails, the bytes are
    interpreted as raw float32 values reshaped to *shape*.

    Raises:
        binascii.Error / ValueError: if the base64 text is malformed or the
            raw fallback cannot be reshaped to *shape*.
    """
    raw = base64.b64decode(weights_b64)
    try:
        # SECURITY: the payload comes from the network. ``allow_pickle=True``
        # would let a crafted archive execute arbitrary code while loading,
        # so pickled object arrays are refused here.
        with np.load(io.BytesIO(raw), allow_pickle=False) as archive:
            return archive['w']
    except Exception:
        # Deliberate best-effort: anything that is not a readable npz with a
        # 'w' entry is treated as a raw float32 buffer. ``frombuffer`` on
        # bytes yields a read-only view, so copy to get a writable array.
        return np.frombuffer(raw, dtype=np.float32).reshape(shape).copy()


# ========== Page routes ==========

@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the landing page."""
    # NOTE(review): the body below contains no HTML markup in this copy of
    # the file — confirm whether tags were lost upstream.
    return """虫群产品版 v1.0

🐛 虫群产品版 v1.0

类脑Meta Model + 四层路由 + 联邦训练

API端点

"""


# ========== API routes ==========

@app.get("/health")
async def health():
    """Report liveness and whether the lazy singletons have been created."""
    brain_status = 'loaded' if _brain is not None else 'lazy'
    chat_status = 'loaded' if _chat is not None else 'lazy'
    return {
        "status": "ok",
        "version": "v1.0",
        "brain": brain_status,
        "chat": chat_status,
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Answer ``req.text``, or store a taught answer when ``req.teach`` is set."""
    chat_engine = get_chat()
    if req.teach:
        chat_engine.teach(req.text, req.teach)
        return ChatResponse(text=f"已学习: {req.text}", mode="teach", ms=0)
    result = chat_engine.chat(req.text)
    return ChatResponse(
        text=result.get('text', ''),
        mode=result.get('mode', 'unknown'),
        ms=result.get('ms', 0),
        confidence=result.get('confidence', 0),
        decoded_words=result.get('decoded_words', []),
    )


@app.post("/train", response_model=TrainResponse)
async def train(req: TrainRequest):
    """Train the brain on ``req.texts`` and report the elapsed milliseconds.

    Raises:
        HTTPException(500): if ``brain.train`` raises.
    """
    # NOTE(review): brain.train is called synchronously inside an async
    # handler; if it is long-running it will stall the event loop — confirm.
    brain = get_brain()
    t0 = time.time()
    try:
        brain.train(req.texts, epochs=req.epochs, lr=req.learning_rate)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    ms = int((time.time() - t0) * 1000)
    return TrainResponse(
        status='ok',
        samples=len(req.texts),
        epochs=req.epochs,
        ms=ms,
    )


@app.get("/weights/{area}")
async def get_weights(area: str):
    """Export one area's weights as base64-encoded compressed npz.

    An area whose weight array is empty yields an empty ``weights_b64`` and
    an empty ``shape``.

    Raises:
        HTTPException(404): unknown area.
        HTTPException(500): export failed.
    """
    brain = get_brain()
    if area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{area}不存在")
    try:
        weights = brain.get_area_weights(area)
        if weights.size == 0:
            return WeightsResponse(area=area, version=0, weights_b64="", shape=[])
        b64 = base64.b64encode(io_loop(weights)).decode()
        return WeightsResponse(
            area=area,
            version=0,
            weights_b64=b64,
            shape=list(weights.shape),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"权重导出失败: {e}") from e


@app.post("/weights/{area}")
async def set_weights(area: str, req: WeightsResponse):
    """Replace one area's weights with the decoded payload in *req*.

    Raises:
        HTTPException(404): unknown area.
        HTTPException(500): payload could not be decoded or applied.
    """
    brain = get_brain()
    if area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{area}不存在")
    try:
        weights = _decode_weights(req.weights_b64, req.shape)
        brain.set_area_weights(area, weights)
        return {"status": "ok", "area": area}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"权重导入失败: {e}") from e


@app.post("/fedavg")
async def fedavg(req: FedAvgRequest):
    """Pass a peer node's weights for ``req.area`` to ``brain.fedavg``.

    Raises:
        HTTPException(404): unknown area.
        HTTPException(500): payload could not be decoded or aggregated.
    """
    brain = get_brain()
    if req.area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{req.area}不存在")
    try:
        # Accepts npz (compressed) first, raw float32 as a fallback.
        incoming = _decode_weights(req.weights_b64, req.shape)
        result = brain.fedavg(req.area, incoming, req.node_id)
        return {"status": "ok", "area": req.area, "result": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"联邦聚合失败: {e}") from e


@app.post("/sync")
async def sync(req: SyncRequest):
    """Export the weights of the requested areas (all areas when none given).

    Unknown areas, areas with empty weights, and areas whose export fails are
    left out of the result.
    """
    brain = get_brain()
    areas = req.areas or list(brain.areas.keys())
    result = {}
    for area in areas:
        if area not in brain.areas:
            continue
        try:
            w = brain.get_area_weights(area)
            if w.size == 0:
                continue
            result[area] = {
                'weights_b64': base64.b64encode(io_loop(w)).decode(),
                'shape': list(w.shape),
                'version': 0,
            }
        except Exception as e:
            # Best-effort: one failing area must not abort the whole sync,
            # but leave a trace instead of dropping the error silently.
            print(f'[Node] sync: skipping area {area}: {e}')
            continue
    return result


@app.get("/stats")
async def stats():
    """Collect brain and chat statistics; failures are reported per section."""
    info = {'version': 'v1.0'}
    try:
        info['brain'] = get_brain().stats()
    except Exception as e:
        info['brain_error'] = str(e)
    try:
        info['chat'] = get_chat().stats()
    except Exception as e:
        info['chat_error'] = str(e)
    return info


# ========== Entry point ==========

if __name__ == '__main__':
    import uvicorn
    port = int(os.environ.get('SWARM_PORT', '7860'))
    uvicorn.run(app, host='0.0.0.0', port=port)