File size: 4,456 Bytes
43464e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import asyncio
import websockets
import json
import numpy as np

class WebSocketModelStorage:
    """Store and fetch model bytes in fixed-size chunks over a WebSocket.

    Every request is a JSON object with ``"operation": "vram"``, a ``"type"``
    of ``"write"`` or ``"read"``, and a ``"block_id"`` of the form
    ``f"{model_id}_{chunk_id}"``. After each request the client reads exactly
    one reply message, decodes it as JSON and returns it to the caller.

    NOTE(review): replies are matched to requests only by order, so one
    instance should not be driven by several coroutines concurrently.
    """

    def __init__(self, uri):
        """Remember *uri*; no connection is opened until ``connect()``."""
        self.uri = uri
        self.websocket = None  # set by connect(), cleared by disconnect()

    async def connect(self):
        """Open the WebSocket connection to ``self.uri``."""
        # max_size=None lifts the library's cap on incoming message size:
        # a read reply carries a whole chunk as a JSON list of integers.
        self.websocket = await websockets.connect(self.uri, max_size=None)

    async def disconnect(self):
        """Close the connection if one is open; safe to call repeatedly."""
        # Clear the attribute first so a closed socket is never reused, even
        # if close() itself raises.
        websocket, self.websocket = self.websocket, None
        if websocket is not None:
            await websocket.close()

    async def _request(self, payload):
        """Send *payload* as JSON and return the JSON-decoded reply."""
        await self.websocket.send(json.dumps(payload))
        return json.loads(await self.websocket.recv())

    @staticmethod
    def _to_bytes(model_data):
        """Return *model_data* as a flat ``bytes`` buffer ready for chunking."""
        if isinstance(model_data, np.ndarray):
            return model_data.tobytes()
        if isinstance(model_data, (bytes, bytearray, memoryview)):
            return bytes(model_data)
        # Anything else is treated as text and encoded as UTF-8.
        return model_data.encode("utf-8")

    async def upload_model_chunk(self, model_id, chunk_id, chunk_data):
        """Write one chunk under block id ``f"{model_id}_{chunk_id}"``.

        Args:
            model_id: Identifier of the model the chunk belongs to.
            chunk_id: Index of the chunk within the model.
            chunk_data: A NumPy array (sent via ``tolist()``), a bytes-like
                object (sent as a list of 0-255 integers), or any other
                JSON-serialisable value (sent unchanged).

        Returns:
            The server's reply, decoded from JSON.
        """
        if isinstance(chunk_data, np.ndarray):
            data = chunk_data.tolist()
        elif isinstance(chunk_data, (bytes, bytearray, memoryview)):
            # json.dumps cannot encode bytes; use the same list-of-ints
            # encoding that upload_model sends.
            data = list(bytes(chunk_data))
        else:
            data = chunk_data
        return await self._request({
            "operation": "vram",
            "type": "write",
            "block_id": f"{model_id}_{chunk_id}",
            "data": data,
        })

    async def download_model_chunk(self, model_id, chunk_id):
        """Read the chunk stored under ``f"{model_id}_{chunk_id}"``.

        Returns:
            The server's reply, decoded from JSON. ``download_model`` reads its
            ``"status"`` key and, on success, its ``"data"`` key.
        """
        return await self._request({
            "operation": "vram",
            "type": "read",
            "block_id": f"{model_id}_{chunk_id}",
        })

    async def upload_model(self, model_id, model_data, chunk_size=1024*1024):  # 1 MiB chunks
        """Upload *model_data* as consecutive chunks of at most *chunk_size* bytes.

        Args:
            model_id: Identifier used to build each chunk's block id.
            model_data: A NumPy array (its ``tobytes()`` buffer is sent), a
                bytes-like object, or a ``str`` (encoded as UTF-8).
            chunk_size: Maximum number of bytes per chunk; must be positive.

        Returns:
            True if every chunk was acknowledged with ``"status": "success"``;
            False as soon as one was not (the remaining chunks are not sent).

        Raises:
            ValueError: If *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")

        model_data_bytes = self._to_bytes(model_data)
        total_size = len(model_data_bytes)
        num_chunks = (total_size + chunk_size - 1) // chunk_size  # ceiling division

        print(f"Uploading model {model_id} in {num_chunks} chunks...")

        # A memoryview lets each chunk be sliced out without copying the buffer.
        view = memoryview(model_data_bytes)
        for i in range(num_chunks):
            start = i * chunk_size
            chunk = view[start:start + chunk_size]  # slicing clamps at the end
            # JSON has no bytes type, so each chunk travels as a list of ints.
            response = await self.upload_model_chunk(model_id, i, list(chunk))
            if response.get("status") != "success":
                print(f"Error uploading chunk {i}: {response.get('message')}")
                return False
            print(f"Uploaded chunk {i+1}/{num_chunks}")
        return True

    async def download_model(self, model_id, num_chunks, dtype=np.float32):
        """Download chunks ``0 .. num_chunks-1`` and rebuild a flat array.

        Args:
            model_id: Identifier the chunks were uploaded under.
            num_chunks: Number of chunks to fetch; the caller must know it.
            dtype: Element type the joined bytes are reinterpreted as. The
                default, float32, keeps the previous behaviour; pass the dtype
                that was uploaded (e.g. ``np.uint8`` for raw bytes or text).

        Returns:
            A 1-D read-only array over the downloaded bytes (built with
            ``np.frombuffer``), or None if any chunk reply did not report
            ``"status": "success"``.

        Raises:
            ValueError: From ``np.frombuffer`` if the total byte count is not a
                multiple of *dtype*'s item size.
        """
        print(f"Downloading model {model_id} with {num_chunks} chunks...")
        downloaded_chunks = []
        for i in range(num_chunks):
            response = await self.download_model_chunk(model_id, i)
            if response.get("status") != "success":
                print(f"Error downloading chunk {i}: {response.get('message')}")
                return None
            # Each chunk arrives as a JSON list of byte values.
            downloaded_chunks.append(np.array(response["data"], dtype=np.uint8).tobytes())
            print(f"Downloaded chunk {i+1}/{num_chunks}")

        return np.frombuffer(b"".join(downloaded_chunks), dtype=dtype)

async def main():
    """Demo: upload a random float32 model, download it again and compare.

    Connects to a local server, always disconnects on exit (including on
    error), and prints the outcome of each step.
    """
    uri = "ws://localhost:7860/ws"
    storage = WebSocketModelStorage(uri)
    await storage.connect()
    try:
        # 5 * 2**20 float32 values at 4 bytes each = 20 MiB of model data.
        dummy_model_data = np.random.rand(1024 * 1024 * 5).astype(np.float32)
        model_id = "test_model_123"
        chunk_size = 1024 * 1024
        # nbytes equals len(tobytes()) without materialising a 20 MiB copy.
        num_chunks = (dummy_model_data.nbytes + chunk_size - 1) // chunk_size
        # Pass chunk_size explicitly so the chunk count above always matches
        # what upload_model actually sends.
        success = await storage.upload_model(model_id, dummy_model_data, chunk_size=chunk_size)
        if not success:
            print(f"Model {model_id} upload failed.")
            return
        print(f"Model {model_id} uploaded successfully.")

        downloaded_model = await storage.download_model(model_id, num_chunks)
        if downloaded_model is None:
            print(f"Model {model_id} download failed.")
            return
        print(f"Model {model_id} downloaded successfully. Shape: {downloaded_model.shape}")

        # Round-trip integrity check (testing aid).
        if np.array_equal(dummy_model_data, downloaded_model):
            print("Downloaded model matches original.")
        else:
            print("Downloaded model DOES NOT match original.")
    finally:
        # Runs on early return and on exceptions, so the socket is never leaked.
        await storage.disconnect()

if __name__ == "__main__":
    # Only run the demo when executed as a script, never on import.
    demo = main()
    asyncio.run(demo)