import asyncio
import json
import logging
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
|
| |
|
class PeerNetwork:
    """WebSocket relay that forwards every text message a peer sends to all other peers.

    Peers connect at ``/ws/{peer_id}`` on the FastAPI app stored in ``self.app``.
    Messages are relayed verbatim; this class never parses them.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, host: str = "localhost", port: int = 8000):
        """Create the FastAPI app and register the ``/ws/{peer_id}`` route.

        Args:
            host: Recorded on the instance only; this class does not start a server.
            port: Recorded on the instance only.
        """
        self.app = FastAPI()
        # Currently connected peers, keyed by the peer_id taken from the URL path.
        self.active_peers: Dict[str, WebSocket] = {}
        self.host = host
        self.port = port

        @self.app.websocket("/ws/{peer_id}")
        async def websocket_endpoint(websocket: WebSocket, peer_id: str):
            await self.connect_peer(websocket, peer_id)
            try:
                while True:
                    data = await websocket.receive_text()
                    await self.broadcast(data, peer_id)
            except WebSocketDisconnect:
                pass  # Normal client-initiated close; nothing to report.
            except Exception:
                # Any other failure ends this peer's session. Log it rather than
                # swallowing it silently so the cause remains diagnosable.
                self._log.exception("Session for peer %r ended with an error", peer_id)
            finally:
                # Always unregister. Pass our own socket so that a newer
                # connection which reused this peer_id is not evicted by us.
                await self.disconnect_peer(peer_id, websocket)

    async def connect_peer(self, websocket: WebSocket, peer_id: str):
        """Accept ``websocket`` and register it under ``peer_id``.

        If ``peer_id`` is already registered, the new socket replaces the old
        entry; the previously registered socket is not closed here.
        """
        await websocket.accept()
        self.active_peers[peer_id] = websocket

    async def disconnect_peer(self, peer_id: str, websocket: Optional[WebSocket] = None):
        """Unregister a peer and close its socket.

        Args:
            peer_id: Identifier the peer registered under.
            websocket: The specific socket being torn down. When given, the
                registry entry is removed only if it still refers to this
                socket, so a stale session cannot evict a newer connection that
                reused ``peer_id``; the given socket is closed either way.
                When omitted, whatever socket is registered under ``peer_id``
                is removed and closed. If nothing is registered, this is a no-op.

        Closing is best-effort: errors raised by ``close()`` are logged at
        debug level and not propagated.
        """
        registered = self.active_peers.get(peer_id)
        target = registered if websocket is None else websocket
        if target is None:
            return
        if registered is target:
            # Unregister *before* awaiting so concurrent broadcasts/disconnects
            # never observe, or double-delete, a half-removed peer.
            del self.active_peers[peer_id]
        try:
            await target.close()
        except Exception:
            # Closing may fail if the client already went away; during teardown
            # that is expected, so only record it for debugging.
            self._log.debug(
                "Ignoring error while closing socket for peer %r", peer_id, exc_info=True
            )

    async def broadcast(self, message: str, sender_id: str):
        """Send ``message`` to every connected peer except ``sender_id``.

        Sends run concurrently. A peer whose send fails is disconnected; the
        failure does not propagate to the caller (and therefore does not end
        the sender's session).
        """
        # Snapshot the recipients: peers may connect/disconnect while we await.
        recipients = [
            (peer_id, websocket)
            for peer_id, websocket in self.active_peers.items()
            if peer_id != sender_id
        ]
        results = await asyncio.gather(
            *(websocket.send_text(message) for _, websocket in recipients),
            return_exceptions=True,
        )
        for (peer_id, websocket), result in zip(recipients, results):
            if isinstance(result, BaseException):
                self._log.warning("Dropping peer %r after failed send: %r", peer_id, result)
                await self.disconnect_peer(peer_id, websocket)
|
| |
|
class OpenPeerClient:
    """Client side of the peer network: sends and receives JSON text messages.

    NOTE(review): ``connect`` is not functional as written; see the FIXME there.
    """

    def __init__(self, network_url: str):
        """Store the network base URL; ``/ws/{peer_id}`` is appended on connect."""
        self.network_url = network_url
        # Open connection object, or None when not connected. Only
        # ``send_text`` / ``receive_text`` / ``close`` are called on it.
        self.websocket: Optional[Any] = None
        self.peer_id: Optional[str] = None

    async def connect(self, peer_id: str):
        """Open a connection to ``{network_url}/ws/{peer_id}`` as ``peer_id``.

        ``peer_id`` is recorded only after the connection call returns.
        """
        # FIXME(review): fastapi.WebSocket (Starlette) is a server-side
        # connection object and its documented API has no ``connect`` method,
        # so this presumably raises AttributeError. An async WebSocket *client*
        # library is needed here.
        websocket = await WebSocket.connect(f"{self.network_url}/ws/{peer_id}")
        self.websocket = websocket
        self.peer_id = peer_id

    async def send_model_update(self, model_state: Dict[str, torch.Tensor]):
        """Send ``model_state`` to the network as a JSON ``model_update`` message.

        Each tensor is encoded as a (nested) Python list. The nesting preserves
        the shape, but dtype and device are not transmitted.

        Raises:
            RuntimeError: if not connected.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to network")

        serialized_state = {
            "type": "model_update",
            "peer_id": self.peer_id,
            # Tensor.tolist() instead of .numpy(): .numpy() rejects tensors that
            # require grad and dtypes NumPy lacks (e.g. bfloat16).
            "state": {
                name: tensor.detach().cpu().tolist()
                for name, tensor in model_state.items()
            },
        }
        await self.websocket.send_text(json.dumps(serialized_state))

    async def receive_updates(self):
        """Yield every incoming text message decoded from JSON, indefinitely.

        Raises:
            RuntimeError: on first iteration, if not connected.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to network")

        while True:
            data = await self.websocket.receive_text()
            yield json.loads(data)

    async def close(self):
        """Close the connection if one is open; a no-op when not connected."""
        websocket, self.websocket = self.websocket, None
        if websocket is not None:
            await websocket.close()
|
| |
|
def create_peer_network(
    host: str = "localhost", port: int = 8000, *, run: bool = True
) -> PeerNetwork:
    """Create a PeerNetwork and, by default, serve it with uvicorn.

    Args:
        host: Interface passed to the network and to uvicorn.
        port: TCP port passed to the network and to uvicorn.
        run: When True (the default, matching the previous behaviour), serve
            ``network.app`` with ``uvicorn.run``. That call runs the server in
            the foreground, so the network is returned only after the server
            stops. Pass ``run=False`` to get the configured network without
            starting a server (e.g. to serve or mount ``network.app`` yourself).

    Returns:
        The PeerNetwork instance.
    """
    network = PeerNetwork(host, port)
    if run:
        import uvicorn  # Imported lazily: only needed when actually serving.

        uvicorn.run(network.app, host=host, port=port)
    return network