# Source: swarm-backend/core/swarm_node.py (uploaded by lk080424)
# Last commit 2286c14 (verified): "fix: add logging import and logger definition"
#!/usr/bin/env python3
"""
虫群v8 — 集成核心 (Swarm Integration Core)
三层架构统一入口:
1. 参数化记忆模型 — 数据存参数,模型即数据库
2. 聚合协议 — 临时服务器,按需组网
3. 经济系统 — 算力货币,好友免费,虫皇税收
核心流程:
用户交互 → 记忆即时写入 → 组建临时服务器 → 调度推理 → 计费 → 结果聚合
"""
import hashlib
import logging
import time
import threading
from datetime import datetime
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
from core.parametric_memory import ParametricMemoryModel, MemoryEncoder
from core.aggregation_protocol.protocol import AggregationProtocol
from core.aggregation_protocol.types import (
NodeInfo, NodeCapability, NodeRole, PermissionLevel, AggregationStrategy,
)
from core.aggregation_protocol.economy import SwarmEconomy
class SwarmNode:
    """
    Swarm node: the complete per-user instance (the "queen").

    Each user runs one SwarmNode, which bundles:
    - a parametric memory model holding personal memories,
    - an aggregation-protocol client used to join the network and form
      temporary servers,
    - an economy account used for compute trading,
    - a local inference service, plus API inference and an MOA aggregator.
    """

    def __init__(
        self,
        node_id: str,
        name: str,
        permission_level: PermissionLevel = PermissionLevel.QUEEN,
        model_config: str = "tiny",
        lora_rank: int = 4,
        write_mode: str = "instant",
        initial_balance: float = 100.0,
    ):
        """
        Build all subsystems of the node.

        Args:
            node_id: Identifier used for the protocol, economy account and
                inference service registration.
            name: Human-readable node name (used in status output and logs).
            permission_level: OVERMIND maps to the HIVE role; every other
                level maps to the QUEEN role.
            model_config: Passed through to ParametricMemoryModel.
            lora_rank: Passed through to ParametricMemoryModel.
            write_mode: Passed through to ParametricMemoryModel.
            initial_balance: Starting balance of this node's economy account.
        """
        self.node_id = node_id
        self.name = name
        self.permission_level = permission_level

        # 1. Parametric memory model
        self.memory = ParametricMemoryModel(
            model_config=model_config,
            lora_rank=lora_rank,
            write_mode=write_mode,
            accumulate_steps=5,
            micro_epochs=3,
        )

        # 2. Aggregation protocol
        cap = NodeCapability(
            model_types=["memory", "chat"],
            compute_score=1.0,
            specializations=["memory", "chat"],
        )
        if permission_level == PermissionLevel.OVERMIND:
            role = NodeRole.HIVE
        else:
            role = NodeRole.QUEEN
        self.protocol = AggregationProtocol(
            node_id=node_id,
            role=role,
            name=name,
            permission_level=permission_level,
            capabilities=["memory", "chat"],
            compute_score=cap.compute_score,
        )

        # 3. Economy: each node keeps its own SwarmEconomy instance.
        self.economy = SwarmEconomy()
        self.economy.register_node(node_id, initial_balance=initial_balance)

        # 4. Inference service (imported locally, as in the original code)
        from core.inference_service import InferenceService, DistributedInferenceBridge
        self.inference = InferenceService(self.memory, node_id)
        self.inference_bridge = DistributedInferenceBridge()
        self.inference_bridge.register_service(node_id, self.inference)

        # 5. API inference + MOA aggregation.
        # The MOA module has its own AggregationStrategy; alias it so it does
        # not shadow the aggregation-protocol enum imported at module level.
        from core.api_inference import APIInferenceManager
        from core.moa_aggregator import MOAAggregator
        from core.moa_aggregator import AggregationStrategy as MOAStrategy
        self.api_manager = APIInferenceManager.from_env()
        self.moa = MOAAggregator(MOAStrategy.ADAPTIVE)
        self._api_models_available = len(self.api_manager.models) > 0

        # Counters exposed through get_status().
        self.stats = {
            "memories_stored": 0,
            "queries_processed": 0,
            "remote_calls_made": 0,
            "remote_calls_received": 0,
            "api_calls_made": 0,
            "total_cost_cc": 0.0,
            "total_earned_cc": 0.0,
        }

    # ============================================================
    # Lifecycle
    # ============================================================
    def start(self):
        """Start the aggregation protocol and announce the node."""
        self.protocol.start()
        print(f"[SwarmNode] {self.name} 启动完成")

    def stop(self):
        """Persist the memory model and announce shutdown."""
        # NOTE(review): start() starts the protocol but nothing here stops it;
        # confirm whether AggregationProtocol needs an explicit shutdown call.
        self.memory.save()
        print(f"[SwarmNode] {self.name} 已停止")

    # ============================================================
    # Core: interaction handling
    # ============================================================
    def chat(self, query: str, user_id: str = "default") -> Dict:
        """
        Unified interaction entry point.

        Flow:
            1. Recall from local memory.
            2. Return immediately if the local answer is confident enough
               (confidence >= 0.7 and more than 5 characters).
            3. Otherwise submit an aggregation task to other online nodes and
               bill each participating member.
            4. Return whichever of the local / remote answers has the higher
               confidence (remote only wins with a non-empty response).

        Args:
            query: The user query.
            user_id: Accepted for interface compatibility; not used here.

        Returns:
            Dict with keys "response", "source", "confidence", "latency_ms",
            "cost_cc" and "node_id".
        """
        started = time.time()
        self.stats["queries_processed"] += 1

        # Step 1: local memory recall
        local_result = self.memory.recall(query, max_tokens=64, temperature=0.3)
        local_confidence = local_result.get("confidence", 0.0)
        local_response = local_result.get("response", "")

        # Step 2: confident local answer -> no network, no cost
        if local_confidence >= 0.7 and len(local_response) > 5:
            return {
                "response": local_response,
                "source": "local_memory",
                "confidence": local_confidence,
                "latency_ms": (time.time() - started) * 1000,
                "cost_cc": 0.0,
                "node_id": self.node_id,
            }

        # Step 3: try the remote aggregation path (best effort)
        remote_result = None
        cost = 0.0
        try:
            online_nodes = self.protocol.discover_nodes()
            available = [n for n in online_nodes if n["node_id"] != self.node_id]
            if available:
                task_result = self.protocol.submit_task(
                    query=query,
                    min_nodes=1,
                    max_nodes=min(3, len(available)),
                    strategy=AggregationStrategy.ADAPTIVE_MIX,
                    timeout_sec=15,
                )
                if task_result.get("status") == "completed":
                    members = task_result.get("members", [])
                    remote_result = {
                        "response": task_result.get("final_response", ""),
                        "confidence": task_result.get("confidence", 0.5),
                        "source": "remote_aggregation",
                        "members": members,
                    }
                    # Billing: update the counters per member so that a failure
                    # part-way through cannot leave the stats out of sync with
                    # the cost reported to the caller.
                    for member_id in members:
                        fee = self.economy.call_compute(
                            caller_id=self.node_id,
                            provider_id=member_id,
                            compute_units=100,
                        )
                        cost += fee
                        self.stats["remote_calls_made"] += 1
                        self.stats["total_cost_cc"] += fee
        except Exception as e:
            # Remote path is optional; fall back to the local answer.
            print(f"[SwarmNode] 远程调用失败: {e}")

        # Step 4: pick the better answer. An empty remote response never
        # replaces the local one, whatever confidence it reports.
        prefer_remote = bool(
            remote_result
            and remote_result["response"]
            and remote_result["confidence"] > local_confidence
        )
        if prefer_remote:
            final_response = remote_result["response"]
            final_confidence = remote_result["confidence"]
            source = "remote_aggregation"
        else:
            final_response = local_response
            final_confidence = local_confidence
            source = "local_memory"

        return {
            "response": final_response,
            "source": source,
            "confidence": final_confidence,
            "latency_ms": (time.time() - started) * 1000,
            "cost_cc": cost,
            "node_id": self.node_id,
        }

    def smart_query(self, query: str, use_api: bool = True,
                    use_moa: bool = True, max_api_models: int = 3) -> Dict:
        """
        Smart query: local memory + API inference + MOA aggregation.

        Flow:
            1. Recall from the local parametric memory.
            2. Query up to ``max_api_models`` API models, passing the local
               answer as context.
            3. Ask other online nodes whose inference service is registered
               on this node's bridge.
            4. Aggregate all candidates with MOA, or pick the single most
               confident candidate.

        Args:
            query: The user query.
            use_api: Whether to call API models.
            use_moa: Whether to aggregate when more than one candidate exists.
            max_api_models: Upper bound forwarded to the API manager.

        Returns:
            Dict with keys "response", "source", "confidence", "latency_ms",
            "cost_cc" and "contributors"; the full (non-early-return) path
            also includes a "detail" dict with candidate counts.
        """
        from core.moa_aggregator import ModelAnswer

        started = time.time()
        self.stats["queries_processed"] += 1
        all_answers = []

        # Step 1: local memory recall
        local_result = self.memory.recall(query, max_tokens=128, temperature=0.5)
        local_resp = local_result.get("response", "")
        local_conf = local_result.get("confidence", 0.0)
        if local_resp and len(local_resp.strip()) > 3:
            all_answers.append(ModelAnswer(
                answer=local_resp, model="local_memory",
                provider="local", confidence=local_conf,
                latency_ms=0, source="memory"
            ))
            # A usable, confident local answer is returned directly when the
            # caller did not ask for API inference.
            if local_conf >= 0.8 and not use_api:
                return {
                    "response": local_resp, "source": "local_memory",
                    "confidence": local_conf,
                    "latency_ms": (time.time() - started) * 1000,
                    "cost_cc": 0.0, "contributors": ["local_memory"],
                }

        # Step 2: API multi-model inference
        api_answers = []
        if use_api and self._api_models_available:
            api_answers = self._gather_api_answers(query, local_resp, max_api_models)
        if api_answers:
            # The local answer was already injected as context into the API
            # calls, so it does not take part in aggregation as well.
            all_answers = [a for a in all_answers if a.source != "memory"]
            all_answers.extend(api_answers)
            logger.info("MOA过滤: 移除local_memory,保留%d个API候选", len(all_answers))

        # Step 3: distributed inference (best effort)
        remote_answers = self._gather_swarm_answers(query)
        all_answers.extend(remote_answers)

        # Step 4: MOA aggregation or best single candidate
        if use_moa and len(all_answers) > 1:
            moa_result = self.moa.aggregate(all_answers, question=query)
            final_answer = moa_result.final_answer
            final_confidence = moa_result.confidence
            source = f"moa({moa_result.strategy})"
            contributors = moa_result.contributors
        elif all_answers:
            best = max(all_answers, key=lambda a: a.confidence)
            final_answer = best.answer
            final_confidence = best.confidence
            source = best.source
            contributors = [best.model]
        else:
            final_answer = ""
            final_confidence = 0.0
            source = "none"
            contributors = []

        return {
            "response": final_answer,
            "source": source,
            "confidence": final_confidence,
            "latency_ms": (time.time() - started) * 1000,
            "cost_cc": 0.0,
            "contributors": contributors,
            "detail": {
                "local": 1 if local_resp else 0,
                "api": len(api_answers),
                "swarm": len(remote_answers),
                "total_candidates": len(all_answers),
            },
        }

    def _gather_api_answers(self, query: str, local_resp: str,
                            max_api_models: int) -> List:
        """Query the API models and return the successful answers as ModelAnswer objects."""
        from core.moa_aggregator import ModelAnswer

        # Use the local memory answer as a context hint for the API models.
        context = ""
        if local_resp:
            context = f"已知信息:{local_resp}\n请结合以上信息回答问题。"
        api_results = self.api_manager.infer_multi(
            query, context=context, max_models=max_api_models
        )
        answers = [
            ModelAnswer(
                answer=r.answer, model=r.model,
                provider=r.provider, confidence=0.7,
                latency_ms=r.latency_ms, source="api"
            )
            for r in api_results
            if r.success and r.answer
        ]
        # Counts the API answers that were accepted as candidates.
        self.stats["api_calls_made"] = self.stats.get("api_calls_made", 0) + len(answers)
        return answers

    def _gather_swarm_answers(self, query: str) -> List:
        """Ask up to two other online nodes for an answer; failures are logged, not raised."""
        from core.moa_aggregator import ModelAnswer

        answers = []
        try:
            online_nodes = self.protocol.discover_nodes()
            peers = [n for n in online_nodes if n["node_id"] != self.node_id]
        except Exception:
            logger.warning("swarm node discovery failed", exc_info=True)
            return answers

        for node_info in peers[:2]:
            try:
                peer_id = node_info["node_id"]
                svc = self.inference_bridge.services.get(peer_id)
                if not svc:
                    continue
                r = svc.infer(query, max_tokens=128)
                if r.get("response"):
                    answers.append(ModelAnswer(
                        answer=r["response"],
                        model=peer_id,
                        provider="swarm", confidence=r.get("confidence", 0.5),
                        latency_ms=r.get("latency_ms", 0), source="swarm"
                    ))
            except Exception:
                # One failing peer must not prevent the others from answering.
                logger.warning("swarm inference on a peer node failed", exc_info=True)
        return answers

    # ============================================================
    # Memory operations
    # ============================================================
    def store_memory(self, user_input: str, ai_response: str,
                     memory_type: str = "chat", importance: float = 0.5) -> str:
        """Store one interaction in the memory model and return the value `memory.store` yields."""
        mid = self.memory.store(user_input, ai_response,
                                memory_type=memory_type, importance=importance)
        self.stats["memories_stored"] += 1
        return mid

    def recall_memory(self, query: str, max_tokens: int = 64) -> Dict:
        """Recall from the memory model; returns the model's result dict unchanged."""
        return self.memory.recall(query, max_tokens=max_tokens)

    # ============================================================
    # Social and economy
    # ============================================================
    def add_friend(self, friend_id: str):
        """Record a friendship in both directions in this node's economy."""
        self.economy.add_friend(self.node_id, friend_id)
        self.economy.add_friend(friend_id, self.node_id)

    def join_circle(self, circle_name: str, member_ids: List[str]):
        """Create ``circle_name`` with this node as creator, then add every id in ``member_ids``."""
        self.economy.create_circle(circle_name, self.node_id)
        for mid in member_ids:
            self.economy.join_circle(circle_name, mid)

    def list_compute(self, units: int, price_per_unit: float,
                     friend_free: bool = True):
        """List this node's inference compute for sale."""
        from core.aggregation_protocol.economy_types import ResourceType
        self.economy.sell_compute(
            provider_id=self.node_id,
            resource=ResourceType.INFERENCE,
            units=units,
            price=price_per_unit,
            friend_free=friend_free,
        )

    def get_balance(self) -> float:
        """Return this node's account balance, or 0.0 when no account is found."""
        acc = self.economy.currency.get_account(self.node_id)
        return acc.balance if acc else 0.0

    # ============================================================
    # Network operations
    # ============================================================
    def connect_to(self, other_node: 'SwarmNode'):
        """
        Connect to another node (local simulation).

        Registers the other node with this node's protocol and makes sure it
        has an account in this node's economy.
        """
        other_info = other_node.protocol.node_info
        self.protocol.add_remote_node(
            node_id=other_info.node_id,
            name=other_info.name,
            permission_level=other_info.permission_level,
            capabilities=other_info.capability.specializations,
            compute_score=other_info.capability.compute_score,
        )
        # Check for the account itself rather than its balance: a balance of
        # 0.0 is a valid state and must not trigger a fresh registration.
        try:
            account = self.economy.currency.get_account(other_node.node_id)
        except Exception:
            account = None
        if not account:
            self.economy.register_node(other_node.node_id, initial_balance=100.0)

    def discover_network(self) -> List[Dict]:
        """Return the nodes reported by the protocol's discovery."""
        return self.protocol.discover_nodes()

    # ============================================================
    # Status
    # ============================================================
    def get_status(self) -> Dict:
        """Return identity, memory status, network status, balance and counters."""
        return {
            "node_id": self.node_id,
            "name": self.name,
            "permission": self.permission_level.value,
            "memory": self.memory.get_status(),
            "network": self.protocol.get_status(),
            "balance": self.get_balance(),
            "stats": self.stats,
        }
class SwarmNetwork:
    """
    Swarm network: creates and manages several SwarmNode instances.

    Intended for local testing and demos. The network keeps its own
    SwarmEconomy in addition to the private economy each node owns, so
    relationship changes are mirrored into the nodes' economies.
    """

    def __init__(self):
        """Start with no nodes and a fresh network-level economy."""
        self.nodes: Dict[str, SwarmNode] = {}
        self.economy = SwarmEconomy()

    def create_node(self, node_id: str, name: str,
                    permission: PermissionLevel = PermissionLevel.QUEEN,
                    model_config: str = "tiny",
                    initial_balance: float = 100.0) -> SwarmNode:
        """Create a node, start it, register it under ``node_id`` and return it."""
        node = SwarmNode(
            node_id=node_id,
            name=name,
            permission_level=permission,
            model_config=model_config,
            initial_balance=initial_balance,
        )
        node.start()
        self.nodes[node_id] = node
        return node

    def connect_all(self):
        """Connect every managed node to every other managed node."""
        peers = list(self.nodes.values())
        for node in peers:
            for other in peers:
                if other is not node:
                    node.connect_to(other)

    def add_friendship(self, id_a: str, id_b: str):
        """
        Make ``id_a`` and ``id_b`` friends.

        Every node owns a private economy, so the friendship is recorded on
        both nodes' ledgers; otherwise only calls billed by ``id_a``'s
        economy would see it. ``id_b`` is skipped when it is not a managed
        node.

        Raises:
            KeyError: if ``id_a`` is not a managed node.
        """
        self.nodes[id_a].add_friend(id_b)
        if id_b in self.nodes:
            self.nodes[id_b].add_friend(id_a)

    def create_circle(self, name: str, member_ids: List[str]):
        """
        Create a circle and add all ``member_ids`` to it.

        The circle is created in the network economy with the first member
        as creator (an empty string when ``member_ids`` is empty), then
        mirrored into the economy of every member that is a managed node.
        """
        creator = member_ids[0] if member_ids else ""
        self.economy.create_circle(name, creator)
        for mid in member_ids:
            self.economy.join_circle(name, mid)
            node = self.nodes.get(mid)
            if node is None:
                continue
            # Mirror into this member's own economy: it creates the circle
            # there, then every other member joins.
            node.economy.create_circle(name, mid)
            for other_id in member_ids:
                if other_id != mid:
                    node.economy.join_circle(name, other_id)

    def get_network_status(self) -> Dict:
        """Return the node count and each node's full status keyed by node id."""
        return {
            "total_nodes": len(self.nodes),
            "nodes": {nid: node.get_status() for nid, node in self.nodes.items()},
        }