#!/usr/bin/env python3
"""
虫群节点服务 — 产品版v1.0
FastAPI后端,支持对话/训练/联邦/权重同步
"""
import base64
import json
import os
import sys
import time
from typing import Dict, List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
# Path setup: make the project root, src/ and the src/{core,chat,node,bridge}
# directories importable, so the lazy imports below (e.g. ``core.brain``,
# ``chat.chat_engine``) resolve. _BASE is three directory levels above this file.
_BASE = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
_SRC = os.path.join(_BASE, 'src')
# Each entry is inserted at position 0, so later entries in this list end up
# *earlier* in sys.path (bridge/ is searched first, the project root last).
for p in [_BASE, _SRC, os.path.join(_SRC, 'core'),
          os.path.join(_SRC, 'chat'), os.path.join(_SRC, 'node'),
          os.path.join(_SRC, 'bridge')]:
    if p not in sys.path:
        sys.path.insert(0, p)
app = FastAPI(version="v1.0", title="虫群产品版")

# Process-wide singletons; None until get_brain()/get_chat() create them on first use.
# NOTE(review): _fed is never read in this file — confirm it is still needed.
_brain = _chat = _fed = None
def get_brain():
    """Return the shared Brain instance, constructing it on the first call.

    ``core.brain`` is only imported when the instance is first needed; later
    calls return the cached object stored in the module-level ``_brain``.
    """
    global _brain
    if _brain is None:
        from core.brain import Brain
        _brain = Brain()
        print('[Node] Brain已初始化')
    return _brain
def get_chat():
    """Return the shared ChatEngine, constructing it on the first call.

    The API key is taken from the ``NIM_API_KEY`` environment variable. If that
    is unset, empty, or the ``'***'`` placeholder, the stripped contents of
    ``~/.swarm/.nim_key`` are used instead when that file exists. The engine is
    bound to the shared Brain from ``get_brain()``.
    """
    global _chat
    if _chat is None:
        from chat.chat_engine import ChatEngine
        key = os.environ.get('NIM_API_KEY', '')
        # Fall back to the key file when no usable key came from the environment.
        if key in ('', '***'):
            key_path = os.path.expanduser('~/.swarm/.nim_key')
            if os.path.exists(key_path):
                with open(key_path) as fh:
                    key = fh.read().strip()
        _chat = ChatEngine(brain=get_brain(), api_key=key)
        print('[Node] ChatEngine已初始化')
    return _chat
# ========== 请求/响应模型 ==========
class ChatRequest(BaseModel):
    # Request body for POST /chat.
    text: str  # the user's message; also the first argument to ChatEngine.teach when teaching
    teach: Optional[str] = None  # when truthy, /chat calls ChatEngine.teach(text, teach) instead of chat(text)
class ChatResponse(BaseModel):
    # Response of POST /chat; for normal chats the fields are copied from the
    # dict returned by ChatEngine.chat(), using these defaults for missing keys.
    text: str  # reply text, or an acknowledgement after a teach request
    mode: str  # 'teach' for teach requests, otherwise the mode reported by the engine
    ms: int  # engine-reported latency in milliseconds; 0 for teach requests
    confidence: float = 0.0
    decoded_words: list = []
class TrainRequest(BaseModel):
    # Request body for POST /train; all fields are forwarded to Brain.train.
    texts: List[str] = []  # training samples
    epochs: int = 1
    learning_rate: float = 0.01  # passed to Brain.train as the lr= keyword
class TrainResponse(BaseModel):
    # Response of POST /train (failures are reported as HTTP 500 instead).
    status: str  # 'ok' on success
    samples: int = 0  # number of texts in the request
    epochs: int = 0  # echoed from the request
    ms: int = 0  # elapsed training time in milliseconds
class WeightsResponse(BaseModel):
    # Serialized weights of one brain area. Returned by GET /weights/{area} and
    # also used as the request body of POST /weights/{area}.
    area: str
    version: int  # always 0 wherever this file builds one
    weights_b64: str  # base64 payload; GET encodes a compressed .npz holding the array under key 'w'
    shape: List[int]  # array shape; POST uses it to reshape a raw float32 payload
class FedAvgRequest(BaseModel):
    # Request body for POST /fedavg: a peer node's weights for one area.
    area: str
    weights_b64: str  # same encoding as WeightsResponse.weights_b64
    shape: List[int]  # only used when the payload is a raw float32 buffer
    version: int  # NOTE(review): not read by the /fedavg handler in this file
    node_id: str = 'unknown'  # forwarded to Brain.fedavg
class SyncRequest(BaseModel):
    # Request body for POST /sync.
    areas: List[str] = []  # areas to export; empty means every area of the brain
# ========== 页面路由 ==========
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve a small HTML landing page listing the node's API endpoints.

    The text is wrapped in real markup because it is served as HTML, where
    bare newlines are not rendered as line breaks.
    """
    return """<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<title>虫群产品版 v1.0</title>
</head>
<body>
<h1>🐛 虫群产品版 v1.0</h1>
<p>类脑Meta Model + 四层路由 + 联邦训练</p>
<h2>API端点</h2>
<ul>
<li>POST /chat — 对话</li>
<li>POST /train — 训练</li>
<li>GET /weights/{area} — 获取权重</li>
<li>POST /weights/{area} — 设置权重</li>
<li>POST /fedavg — 联邦聚合</li>
<li>POST /sync — 同步权重</li>
<li>GET /health — 健康检查</li>
<li>GET /stats — 统计信息</li>
</ul>
</body>
</html>
"""
# ========== API路由 ==========
@app.get("/health")
async def health():
    """Liveness probe that reports whether the lazy singletons exist yet.

    It never triggers initialisation itself; it only inspects the globals.

    Returns:
        dict with ``status``, ``version``, and ``'loaded'``/``'lazy'`` for
        ``brain`` and ``chat``.
    """
    # Compare against None (as get_brain/get_chat do) rather than relying on
    # truthiness: an object defining __len__/__bool__ could be falsy while loaded.
    return {
        "status": "ok", "version": "v1.0",
        "brain": 'loaded' if _brain is not None else 'lazy',
        "chat": 'loaded' if _chat is not None else 'lazy',
    }
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Chat with the engine, or teach it when ``req.teach`` is set.

    With a truthy ``teach`` field this calls ``ChatEngine.teach(text, teach)``
    and returns an acknowledgement in mode ``'teach'``. Otherwise ``text`` goes
    to ``ChatEngine.chat`` and the returned dict is mapped onto ChatResponse,
    with defaults for any missing keys.
    """
    engine = get_chat()
    if not req.teach:
        reply = engine.chat(req.text)
        return ChatResponse(
            text=reply.get('text', ''),
            mode=reply.get('mode', 'unknown'),
            ms=reply.get('ms', 0),
            confidence=reply.get('confidence', 0),
            decoded_words=reply.get('decoded_words', []),
        )
    engine.teach(req.text, req.teach)
    return ChatResponse(text=f"已学习: {req.text}", mode="teach", ms=0)
@app.post("/train", response_model=TrainResponse)
async def train(req: TrainRequest):
    """Train the shared Brain on ``req.texts`` and report how long it took.

    NOTE: ``Brain.train`` is called synchronously inside an ``async def``
    handler, so it runs on the event loop and other requests wait until it
    returns.

    Raises:
        HTTPException: 500 with the error message if ``Brain.train`` fails.
    """
    brain = get_brain()
    # perf_counter is monotonic, so the duration is unaffected by clock changes.
    started = time.perf_counter()
    try:
        brain.train(req.texts, epochs=req.epochs, lr=req.learning_rate)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    elapsed_ms = int((time.perf_counter() - started) * 1000)
    return TrainResponse(
        status='ok', samples=len(req.texts),
        epochs=req.epochs, ms=elapsed_ms,
    )
@app.get("/weights/{area}")
async def get_weights(area: str):
    """Export one area's weights as base64-encoded compressed ``.npz`` bytes.

    The array is stored under key ``'w'`` (see ``io_loop``). An area with no
    weights yields an empty payload with ``shape == []``. ``version`` is
    always 0.

    Raises:
        HTTPException: 404 for an unknown area, 500 if the export fails.
    """
    brain = get_brain()
    if area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{area}不存在")
    try:
        weights = brain.get_area_weights(area)
        if weights.size == 0:
            # Same JSON as before, but built from the declared response model.
            return WeightsResponse(area=area, version=0, weights_b64="", shape=[])
        payload = base64.b64encode(io_loop(weights)).decode()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"权重导出失败: {e}") from e
    return WeightsResponse(
        area=area, version=0,
        weights_b64=payload, shape=list(weights.shape),
    )
@app.post("/weights/{area}")
async def set_weights(area: str, req: WeightsResponse):
    """Replace one area's weights with the array encoded in ``req``.

    ``req.weights_b64`` is base64 of one of:

    * an ``.npz`` archive with the array under key ``'w'`` (the format that
      GET /weights/{area} produces);
    * a bare ``.npy`` array;
    * a raw native-endian float32 buffer, reshaped to ``req.shape``.

    ``req.area`` and ``req.version`` are not used.

    Raises:
        HTTPException: 404 for an unknown area; 500 if decoding or
        ``Brain.set_area_weights`` fails.
    """
    import io

    brain = get_brain()
    if area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{area}不存在")
    try:
        raw = base64.b64decode(req.weights_b64)
        try:
            # SECURITY: the payload comes from the network. With
            # allow_pickle=True, np.load unpickles anything that is not an
            # .npy/.npz container, which lets a crafted body execute code.
            loaded = np.load(io.BytesIO(raw), allow_pickle=False)
        except (ValueError, EOFError, OSError):
            # Not an .npy/.npz container: treat it as a raw float32 buffer.
            # copy() because frombuffer returns a read-only view of `raw`.
            weights = np.frombuffer(raw, dtype=np.float32).reshape(req.shape).copy()
        else:
            if isinstance(loaded, np.ndarray):  # bare .npy payload
                weights = loaded
            else:
                # NpzFile wraps an open zip archive; close it once 'w' is read.
                with loaded:
                    weights = loaded['w']
        brain.set_area_weights(area, weights)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"权重导入失败: {e}") from e
    return {"status": "ok", "area": area}
@app.post("/fedavg")
async def fedavg(req: FedAvgRequest):
    """Pass a peer node's weights for ``req.area`` to ``Brain.fedavg``.

    ``req.weights_b64`` is base64 of one of:

    * an ``.npz`` archive with the array under key ``'w'``;
    * a bare ``.npy`` array;
    * a raw native-endian float32 buffer, reshaped to ``req.shape``.

    The aggregation itself happens inside ``Brain.fedavg``, whose return value
    is echoed back as ``result``.

    Raises:
        HTTPException: 404 for an unknown area; 500 if decoding or aggregation
        fails.
    """
    import io

    brain = get_brain()
    if req.area not in brain.areas:
        raise HTTPException(status_code=404, detail=f"区域{req.area}不存在")
    try:
        raw = base64.b64decode(req.weights_b64)
        try:
            # SECURITY: the payload comes from a remote node. With
            # allow_pickle=True, np.load unpickles anything that is not an
            # .npy/.npz container, which lets a crafted body execute code.
            loaded = np.load(io.BytesIO(raw), allow_pickle=False)
        except (ValueError, EOFError, OSError):
            # Not an .npy/.npz container: treat it as a raw float32 buffer.
            # copy() because frombuffer returns a read-only view of `raw`.
            incoming = np.frombuffer(raw, dtype=np.float32).reshape(req.shape).copy()
        else:
            if isinstance(loaded, np.ndarray):  # bare .npy payload
                incoming = loaded
            else:
                # NpzFile wraps an open zip archive; close it once 'w' is read.
                with loaded:
                    incoming = loaded['w']
        result = brain.fedavg(req.area, incoming, req.node_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"联邦聚合失败: {e}") from e
    return {"status": "ok", "area": req.area, "result": result}
@app.post("/sync")
async def sync(req: SyncRequest):
    """Export the weights of several areas in one response.

    Args:
        req.areas: area names to export; an empty list means every area.

    Returns:
        ``{area: {'weights_b64', 'shape', 'version'}}`` using the same encoding
        as GET /weights/{area}. Unknown areas, empty areas, and areas whose
        export fails are omitted.
    """
    brain = get_brain()
    wanted = req.areas or list(brain.areas.keys())
    result = {}
    for area in wanted:
        if area not in brain.areas:
            continue
        try:
            w = brain.get_area_weights(area)
            if w.size == 0:
                continue
            payload = base64.b64encode(io_loop(w)).decode()
        except Exception as e:
            # Best effort: one broken area must not fail the whole sync,
            # but the failure is logged instead of silently discarded.
            print(f'[Node] 区域{area}同步失败: {e}')
            continue
        result[area] = {
            'weights_b64': payload,
            'shape': list(w.shape),
            'version': 0,
        }
    return result
@app.get("/stats")
async def stats():
    """Collect ``stats()`` from the Brain and the ChatEngine.

    Both are created on demand if they do not exist yet. A failure in either
    step is reported under ``'<name>_error'`` instead of failing the request.
    """
    info = {'version': 'v1.0'}
    sources = (
        ('brain', 'brain_error', get_brain),
        ('chat', 'chat_error', get_chat),
    )
    for ok_key, err_key, factory in sources:
        try:
            info[ok_key] = factory().stats()
        except Exception as exc:
            info[err_key] = str(exc)
    return info
def io_loop(arr: np.ndarray) -> bytes:
    """Serialise *arr* to compressed ``.npz`` bytes, stored under key ``'w'``.

    ``np.load(io.BytesIO(data))['w']`` recovers the array.
    """
    import io
    with io.BytesIO() as sink:
        np.savez_compressed(sink, w=arr)
        return sink.getvalue()
# ========== 启动 ==========
if __name__ == '__main__':
    import uvicorn
    # Bind address and port are both configurable; the defaults keep the
    # previous behaviour (all interfaces, port 7860).
    # NOTE(review): no endpoint in this file has authentication, so binding to
    # 0.0.0.0 exposes weight upload/aggregation to the network — confirm intended.
    host = os.environ.get('SWARM_HOST', '0.0.0.0')
    port = int(os.environ.get('SWARM_PORT', '7860'))
    uvicorn.run(app, host=host, port=port)