import asyncio
import json

import numpy as np
import websockets

# Default number of bytes per uploaded chunk (1 MiB). Shared by upload_model
# and main() so the chunk count used for download always matches the upload.
DEFAULT_CHUNK_SIZE = 1024 * 1024


class WebSocketModelStorage:
    """Store and retrieve model blobs over a WebSocket "vram" block protocol.

    A model is split into fixed-size byte chunks; chunk ``i`` of model
    ``model_id`` is written to / read from block id ``f"{model_id}_{i}"``.
    Every request is one JSON message, answered by one JSON reply read with
    ``recv()``.

    NOTE(review): replies are matched to requests purely by order, so one
    instance should not be used for overlapping requests from concurrent tasks.
    """

    def __init__(self, uri):
        """Remember the server *uri*; no connection is opened until ``connect``."""
        self.uri = uri
        self.websocket = None

    async def connect(self):
        """Open the WebSocket connection with the incoming message size limit disabled."""
        self.websocket = await websockets.connect(self.uri, max_size=None)

    async def disconnect(self):
        """Close the connection if one is open, and forget it.

        The attribute is cleared before closing so that the instance is left in
        the "not connected" state even if ``close()`` raises.
        """
        websocket, self.websocket = self.websocket, None
        if websocket is not None:
            await websocket.close()

    async def _request(self, payload):
        """Send *payload* as JSON and return the decoded JSON reply.

        Raises:
            RuntimeError: if ``connect`` has not been called (or after ``disconnect``).
        """
        if self.websocket is None:
            raise RuntimeError("Not connected; call connect() first")
        await self.websocket.send(json.dumps(payload))
        response = await self.websocket.recv()
        return json.loads(response)

    async def upload_model_chunk(self, model_id, chunk_id, chunk_data):
        """Write *chunk_data* to block ``f"{model_id}_{chunk_id}"``; return the server reply.

        NumPy arrays are converted with ``tolist()`` because they are not JSON
        serializable; any other value is sent as-is in the ``"data"`` field.
        """
        if isinstance(chunk_data, np.ndarray):
            chunk_data = chunk_data.tolist()
        payload = {
            "operation": "vram",
            "type": "write",
            "block_id": f"{model_id}_{chunk_id}",
            "data": chunk_data,
        }
        return await self._request(payload)

    async def download_model_chunk(self, model_id, chunk_id):
        """Read block ``f"{model_id}_{chunk_id}"`` and return the decoded server reply."""
        payload = {
            "operation": "vram",
            "type": "read",
            "block_id": f"{model_id}_{chunk_id}",
        }
        return await self._request(payload)

    @staticmethod
    def _to_bytes(model_data):
        """Return the raw bytes to upload for *model_data*.

        ndarray -> its raw buffer (``tobytes``); bytes-like -> copied as bytes;
        anything else is treated as ``str`` and UTF-8 encoded.
        """
        if isinstance(model_data, np.ndarray):
            return model_data.tobytes()
        if isinstance(model_data, (bytes, bytearray, memoryview)):
            return bytes(model_data)
        return model_data.encode("utf-8")

    @staticmethod
    def _chunk_count(total_size, chunk_size):
        """Return how many *chunk_size*-byte chunks cover *total_size* bytes (ceiling division).

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        return (total_size + chunk_size - 1) // chunk_size

    async def upload_model(self, model_id, model_data, chunk_size=DEFAULT_CHUNK_SIZE):
        """Upload *model_data* as consecutive chunks of at most *chunk_size* bytes.

        Args:
            model_id: prefix of the block ids the chunks are stored under.
            model_data: ndarray (uploaded as its raw buffer), bytes-like, or str
                (UTF-8 encoded).
            chunk_size: bytes per chunk; the same value must be used to compute
                ``num_chunks`` for ``download_model``.

        Returns:
            True if every chunk reply has ``status == "success"``; False on the
            first failing chunk (remaining chunks are not sent).

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        model_data_bytes = self._to_bytes(model_data)
        total_size = len(model_data_bytes)
        num_chunks = self._chunk_count(total_size, chunk_size)
        print(f"Uploading model {model_id} in {num_chunks} chunks...")

        for i in range(num_chunks):
            # Slicing clamps at the end of the buffer, so the last chunk may be short.
            chunk = model_data_bytes[i * chunk_size:(i + 1) * chunk_size]
            # The protocol carries bytes as a JSON list of ints (0-255).
            response = await self.upload_model_chunk(model_id, i, list(chunk))
            if response.get("status") != "success":
                print(f"Error uploading chunk {i}: {response.get('message')}")
                return False
            print(f"Uploaded chunk {i+1}/{num_chunks}")
        return True

    async def download_model(self, model_id, num_chunks, dtype=np.float32):
        """Download *num_chunks* chunks of *model_id* and reassemble them.

        Args:
            model_id: prefix of the block ids to read.
            num_chunks: number of chunks written by ``upload_model``.
            dtype: element type to reinterpret the joined bytes as; must match
                the dtype of the uploaded array (defaults to float32).

        Returns:
            A read-only 1-D ndarray view over the downloaded bytes, or None if
            any chunk reply is not ``status == "success"``.
        """
        print(f"Downloading model {model_id} with {num_chunks} chunks...")
        downloaded_chunks = []
        for i in range(num_chunks):
            response = await self.download_model_chunk(model_id, i)
            if response.get("status") != "success":
                print(f"Error downloading chunk {i}: {response.get('message')}")
                return None
            # Reply "data" is expected to be a list of byte values (0-255).
            downloaded_chunks.append(np.array(response["data"], dtype=np.uint8).tobytes())
            print(f"Downloaded chunk {i+1}/{num_chunks}")

        full_model_bytes = b"".join(downloaded_chunks)
        return np.frombuffer(full_model_bytes, dtype=dtype)


async def main():
    """Demo: upload a random float32 model, download it again, and compare."""
    uri = "ws://localhost:7860/ws"
    storage = WebSocketModelStorage(uri)
    await storage.connect()
    try:
        # 5 * 2**20 float32 elements = 20 MiB of raw data.
        dummy_model_data = np.random.rand(1024 * 1024 * 5).astype(np.float32)
        model_id = "test_model_123"
        chunk_size = DEFAULT_CHUNK_SIZE
        total_size = dummy_model_data.nbytes
        num_chunks = WebSocketModelStorage._chunk_count(total_size, chunk_size)

        success = await storage.upload_model(model_id, dummy_model_data, chunk_size=chunk_size)
        if not success:
            print(f"Model {model_id} upload failed.")
            return
        print(f"Model {model_id} uploaded successfully.")

        downloaded_model = await storage.download_model(
            model_id, num_chunks, dtype=dummy_model_data.dtype
        )
        if downloaded_model is None:
            print(f"Model {model_id} download failed.")
            return
        print(f"Model {model_id} downloaded successfully. Shape: {downloaded_model.shape}")

        # Verify integrity (optional, for testing purposes)
        if np.array_equal(dummy_model_data, downloaded_model):
            print("Downloaded model matches original.")
        else:
            print("Downloaded model DOES NOT match original.")
    finally:
        # Always release the connection, even if upload/download raised.
        await storage.disconnect()


if __name__ == "__main__":
    asyncio.run(main())