Spaces:
Sleeping
Sleeping
| import asyncio | |
| import websockets | |
| import json | |
| import numpy as np | |
class WebSocketModelStorage:
    """Store and fetch model data as byte chunks over a single WebSocket.

    Every request is a JSON object with ``"operation": "vram"``, a ``"type"`` of
    ``"write"`` or ``"read"``, and a ``"block_id"`` of the form
    ``"{model_id}_{chunk_id}"``. Exactly one JSON reply is read per request.
    Chunk bytes travel as a JSON list of ints.
    """

    def __init__(self, uri):
        """Remember *uri*; no connection is opened until :meth:`connect`."""
        self.uri = uri
        self.websocket = None  # open connection between connect() and disconnect()

    async def connect(self):
        """Open the WebSocket connection with no incoming message-size limit."""
        # max_size=None disables the receive-size limit. Downloaded chunks
        # arrive as JSON int lists, which are much larger than the raw bytes.
        self.websocket = await websockets.connect(self.uri, max_size=None)

    async def disconnect(self):
        """Close the connection if one is open; calling it again is a no-op."""
        if self.websocket is not None:
            try:
                await self.websocket.close()
            finally:
                self.websocket = None

    async def _request(self, payload):
        """Send *payload* as JSON and return the decoded JSON reply.

        Raises:
            RuntimeError: if there is no open connection.
        """
        if self.websocket is None:
            raise RuntimeError("WebSocketModelStorage is not connected; call connect() first")
        await self.websocket.send(json.dumps(payload))
        # NOTE(review): replies are not correlated with requests. This assumes
        # the server answers each request exactly once, in order, so do not
        # issue concurrent requests on the same instance.
        return json.loads(await self.websocket.recv())

    async def upload_model_chunk(self, model_id, chunk_id, chunk_data):
        """Write one chunk under block id ``f"{model_id}_{chunk_id}"``.

        Args:
            model_id: prefix of the block id.
            chunk_id: suffix of the block id (the chunk index in upload_model).
            chunk_data: a NumPy array (sent via ``tolist()``), a bytes-like
                object (sent as a list of ints 0-255), or any other
                JSON-serializable value (sent unchanged).

        Returns:
            The server's decoded JSON reply.
        """
        if isinstance(chunk_data, np.ndarray):
            chunk_data = chunk_data.tolist()
        elif isinstance(chunk_data, (bytes, bytearray, memoryview)):
            chunk_data = list(bytes(chunk_data))
        return await self._request({
            "operation": "vram",
            "type": "write",
            "block_id": f"{model_id}_{chunk_id}",
            "data": chunk_data,
        })

    async def download_model_chunk(self, model_id, chunk_id):
        """Read the chunk stored under ``f"{model_id}_{chunk_id}"``.

        Returns:
            The server's decoded JSON reply. upload_model/download_model expect
            a ``"status"`` key and, on success, a ``"data"`` list of ints.
        """
        return await self._request({
            "operation": "vram",
            "type": "read",
            "block_id": f"{model_id}_{chunk_id}",
        })

    async def upload_model(self, model_id, model_data, chunk_size=1024*1024):  # 1MB chunk size
        """Upload *model_data* as consecutive raw-byte chunks numbered from 0.

        Args:
            model_id: prefix for every chunk's block id.
            model_data: a NumPy array (its ``tobytes()`` buffer is sent), a
                bytes-like object (sent as-is), or a ``str`` (UTF-8 encoded).
            chunk_size: maximum bytes per chunk; must be positive. The caller
                needs the same value to compute ``num_chunks`` for download_model.

        Returns:
            True if every chunk was acknowledged with ``status == "success"``;
            False at the first failure, in which case later chunks are not sent.

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if isinstance(model_data, np.ndarray):
            model_data_bytes = model_data.tobytes()
        elif isinstance(model_data, (bytes, bytearray, memoryview)):
            model_data_bytes = bytes(model_data)
        else:
            model_data_bytes = model_data.encode("utf-8")  # assumes str input

        total_size = len(model_data_bytes)
        num_chunks = (total_size + chunk_size - 1) // chunk_size  # ceiling division
        print(f"Uploading model {model_id} in {num_chunks} chunks...")

        view = memoryview(model_data_bytes)  # zero-copy slicing; the last slice is clamped
        for i in range(num_chunks):
            chunk = view[i * chunk_size:(i + 1) * chunk_size]
            response = await self.upload_model_chunk(model_id, i, chunk)
            if response.get("status") != "success":
                print(f"Error uploading chunk {i}: {response.get('message')}")
                return False
            print(f"Uploaded chunk {i+1}/{num_chunks}")
        return True

    async def download_model(self, model_id, num_chunks, dtype=np.float32):
        """Download chunks ``0..num_chunks-1`` and rebuild a 1-D NumPy array.

        Args:
            model_id: prefix used when the model was uploaded.
            num_chunks: number of chunks to fetch; it must match the upload.
            dtype: element type used to reinterpret the joined bytes. It
                defaults to float32; pass the original array's dtype, or
                ``np.uint8`` for raw bytes/strings.

        Returns:
            A read-only 1-D array viewing the joined bytes (``np.frombuffer``
            over immutable ``bytes``), or None if any chunk read fails.

        Raises:
            ValueError: from ``np.frombuffer`` if the total byte count is not a
                multiple of ``dtype``'s item size.
        """
        print(f"Downloading model {model_id} with {num_chunks} chunks...")
        downloaded_chunks = []
        for i in range(num_chunks):
            response = await self.download_model_chunk(model_id, i)
            if response.get("status") != "success":
                print(f"Error downloading chunk {i}: {response.get('message')}")
                return None
            downloaded_chunks.append(np.array(response["data"], dtype=np.uint8).tobytes())
            print(f"Downloaded chunk {i+1}/{num_chunks}")

        # Reconstruct the model from downloaded chunks
        full_model_bytes = b"".join(downloaded_chunks)
        return np.frombuffer(full_model_bytes, dtype=dtype)
async def main():
    """Demo: round-trip a random float32 model through a local storage server.

    Connects to ``ws://localhost:7860/ws``, uploads the model in 1 MiB chunks,
    downloads it again and reports whether the data matches. The connection is
    closed even if a step raises.
    """
    uri = "ws://localhost:7860/ws"
    storage = WebSocketModelStorage(uri)
    await storage.connect()
    try:
        # Example usage: upload a dummy model of 5 * 1024 * 1024 float32 values (20 MiB).
        dummy_model_data = np.random.rand(1024 * 1024 * 5).astype(np.float32)
        model_id = "test_model_123"
        chunk_size = 1024 * 1024
        # chunk_size is passed explicitly to upload_model below, so this count
        # always matches what was uploaded. nbytes avoids copying the array.
        num_chunks = (dummy_model_data.nbytes + chunk_size - 1) // chunk_size

        success = await storage.upload_model(model_id, dummy_model_data, chunk_size=chunk_size)
        if not success:
            print(f"Model {model_id} upload failed.")
            return
        print(f"Model {model_id} uploaded successfully.")

        # Test download
        downloaded_model = await storage.download_model(model_id, num_chunks)
        if downloaded_model is None:
            print(f"Model {model_id} download failed.")
            return
        print(f"Model {model_id} downloaded successfully. Shape: {downloaded_model.shape}")

        # Verify integrity (optional, for testing purposes)
        if np.array_equal(dummy_model_data, downloaded_model):
            print("Downloaded model matches original.")
        else:
            print("Downloaded model DOES NOT match original.")
    finally:
        await storage.disconnect()


if __name__ == "__main__":
    asyncio.run(main())