# INTAI / websocket_model_storage.py
# Factor Studios
# Upload 36 files
# 43464e3 verified
import asyncio
import websockets
import json
import numpy as np
class WebSocketModelStorage:
    """Client that stores and retrieves model data in chunks over a WebSocket.

    Every request is a JSON object with ``"operation": "vram"``, a ``"type"``
    of ``"write"`` or ``"read"``, and a ``"block_id"`` of the form
    ``"<model_id>_<chunk_id>"``. Every reply is expected to be a JSON object;
    the bulk helpers treat ``"status" == "success"`` as success and report
    the reply's ``"message"`` otherwise.

    The instance can be used as an async context manager, which connects on
    entry and disconnects on exit.
    """

    def __init__(self, uri):
        """Remember the server ``uri``; no connection is opened yet."""
        self.uri = uri
        self.websocket = None

    async def __aenter__(self):
        """Open the connection and return this storage object."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the connection, whether or not the body raised."""
        await self.disconnect()

    async def connect(self):
        """Open the WebSocket connection to ``self.uri``."""
        # max_size=None lifts the incoming message size limit so that large
        # chunk replies are not rejected by the client.
        self.websocket = await websockets.connect(self.uri, max_size=None)

    async def disconnect(self):
        """Close the connection if one is open; safe to call repeatedly."""
        if self.websocket is None:
            return
        try:
            await self.websocket.close()
        finally:
            # Drop the reference so a closed socket is never reused and a
            # second disconnect() is a no-op.
            self.websocket = None

    async def _request(self, payload):
        """Send ``payload`` as JSON and return the decoded JSON reply.

        Raises:
            RuntimeError: If ``connect()`` has not been called (or the
                connection has been closed with ``disconnect()``).
        """
        if self.websocket is None:
            raise RuntimeError("WebSocket is not connected; call connect() first.")
        await self.websocket.send(json.dumps(payload))
        return json.loads(await self.websocket.recv())

    async def upload_model_chunk(self, model_id, chunk_id, chunk_data):
        """Write one chunk to the server and return the decoded reply.

        Args:
            model_id: Identifier of the model the chunk belongs to.
            chunk_id: Index of the chunk within the model.
            chunk_data: A NumPy array, ``bytes``/``bytearray``, or any value
                that is already JSON-serializable (e.g. a list of ints).
        """
        if isinstance(chunk_data, np.ndarray):
            data = chunk_data.tolist()
        elif isinstance(chunk_data, (bytes, bytearray)):
            # Raw bytes are not JSON-serializable; send them as a list of ints.
            data = list(chunk_data)
        else:
            data = chunk_data
        return await self._request({
            "operation": "vram",
            "type": "write",
            "block_id": f"{model_id}_{chunk_id}",
            "data": data,
        })

    async def download_model_chunk(self, model_id, chunk_id):
        """Read one chunk from the server and return the decoded reply."""
        return await self._request({
            "operation": "vram",
            "type": "read",
            "block_id": f"{model_id}_{chunk_id}",
        })

    async def upload_model(self, model_id, model_data, chunk_size=1024 * 1024):
        """Upload ``model_data`` split into chunks of ``chunk_size`` bytes.

        Args:
            model_id: Identifier used to build each chunk's block id.
            model_data: A NumPy array (sent as its raw bytes), a bytes-like
                object (``bytes``, ``bytearray``, ``memoryview``), or a
                ``str`` (sent UTF-8 encoded).
            chunk_size: Maximum number of bytes per chunk (default 1 MiB).

        Returns:
            ``True`` if every chunk was acknowledged with
            ``"status": "success"``, ``False`` as soon as one is not.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        if isinstance(model_data, np.ndarray):
            raw = model_data.tobytes()
        elif isinstance(model_data, (bytes, bytearray, memoryview)):
            raw = bytes(model_data)
        else:
            raw = model_data.encode("utf-8")

        total_size = len(raw)
        num_chunks = (total_size + chunk_size - 1) // chunk_size
        print(f"Uploading model {model_id} in {num_chunks} chunks...")

        for index, start in enumerate(range(0, total_size, chunk_size)):
            # Slicing clamps at the end of the buffer, so the last chunk may
            # be shorter than chunk_size.
            chunk = raw[start:start + chunk_size]
            response = await self.upload_model_chunk(model_id, index, list(chunk))
            if response.get("status") != "success":
                print(f"Error uploading chunk {index}: {response.get('message')}")
                return False
            print(f"Uploaded chunk {index + 1}/{num_chunks}")
        return True

    async def download_model(self, model_id, num_chunks, dtype=np.float32):
        """Download ``num_chunks`` chunks and reassemble them into an array.

        Args:
            model_id: Identifier used to build each chunk's block id.
            num_chunks: Number of chunks to fetch, starting at chunk 0.
            dtype: Element type used to reinterpret the downloaded bytes.
                Defaults to ``np.float32``; pass the dtype of the array that
                was uploaded.

        Returns:
            A 1-D array viewing the joined bytes as ``dtype`` (read-only,
            since it is backed by an immutable ``bytes`` object), or ``None``
            if any chunk reply is not ``"status": "success"``.

        Raises:
            ValueError: If the total byte count is not a multiple of the
                item size of ``dtype`` (raised by ``np.frombuffer``).
        """
        print(f"Downloading model {model_id} with {num_chunks} chunks...")
        downloaded_chunks = []
        for index in range(num_chunks):
            response = await self.download_model_chunk(model_id, index)
            if response.get("status") != "success":
                print(f"Error downloading chunk {index}: {response.get('message')}")
                return None
            # Each chunk arrives as a list of byte values; pack it back into
            # raw bytes before joining.
            downloaded_chunks.append(
                np.array(response["data"], dtype=np.uint8).tobytes()
            )
            print(f"Downloaded chunk {index + 1}/{num_chunks}")

        full_model_bytes = b"".join(downloaded_chunks)
        return np.frombuffer(full_model_bytes, dtype=dtype)
async def main():
    """Round-trip a random float32 array through a local storage server.

    Connects to ``ws://localhost:7860/ws``, uploads a dummy model, downloads
    it again, and prints whether the downloaded data matches the original.
    The connection is closed even if one of the steps raises.
    """
    uri = "ws://localhost:7860/ws"
    storage = WebSocketModelStorage(uri)
    await storage.connect()
    try:
        # Example usage: upload a dummy model.
        # 5 * 2**20 float32 values occupy 20 MiB, i.e. 20 chunks of 1 MiB.
        dummy_model_data = np.random.rand(1024 * 1024 * 5).astype(np.float32)
        model_id = "test_model_123"

        # The same chunk_size is passed to upload_model below, so the chunk
        # count computed here always matches what is actually uploaded.
        chunk_size = 1024 * 1024
        num_chunks = (dummy_model_data.nbytes + chunk_size - 1) // chunk_size

        success = await storage.upload_model(model_id, dummy_model_data, chunk_size)
        if not success:
            print(f"Model {model_id} upload failed.")
            return
        print(f"Model {model_id} uploaded successfully.")

        # Test download
        downloaded_model = await storage.download_model(model_id, num_chunks)
        if downloaded_model is None:
            print(f"Model {model_id} download failed.")
            return
        print(f"Model {model_id} downloaded successfully. Shape: {downloaded_model.shape}")

        # Verify integrity (optional, for testing purposes)
        if np.array_equal(dummy_model_data, downloaded_model):
            print("Downloaded model matches original.")
        else:
            print("Downloaded model DOES NOT match original.")
    finally:
        # Always release the socket, including on the early returns above
        # and when an upload/download step raises.
        await storage.disconnect()


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    asyncio.run(main())