|
|
"""
|
|
|
Multithreaded remote storage implementation for virtual GPU.
|
|
|
Provides thread-safe distributed storage with HuggingFace and DuckDB backend.
|
|
|
"""
|
|
|
|
|
|
import asyncio
import hashlib
import json
import logging
import os
import queue
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, Any, Optional, Union, List

import duckdb
import numpy as np
from huggingface_hub import HfApi, HfFileSystem

from config import DB_URL, get_hf_token_cached
from tensor_storage import TensorStorage
|
|
|
|
|
|
@dataclass
class StorageBlock:
    """Record describing one stored block and the thread that owns it.

    Only ``block_id`` and ``size`` are required; every other field has a
    default so a block can be created before it is assigned to a thread.
    """

    # Unique identifier of the block.
    block_id: str

    # Size of the block (the storage code in this file passes the payload's
    # byte length).
    size: int

    # Identifier of the owning thread; None when the block has no owner.
    thread_id: Optional[int] = None

    # Time of the most recent access (this file fills it from time.time()).
    last_accessed: float = 0.0

    # Whether the block is currently marked as locked.
    is_locked: bool = False

    # Optional in-memory payload attached to the block.
    data: Any = None
|
|
|
|
|
|
class ConnectionPool:
    """Pool of DuckDB connection handles that all share ONE in-memory database.

    Every ``duckdb.connect(":memory:")`` call opens a brand-new, empty
    database, so a pool built from independent ``connect`` calls would give
    each borrower its own private set of tables.  The pool therefore opens a
    single root connection and hands out ``cursor()`` handles of it; DuckDB
    cursors are independent connections onto the same database, which is the
    documented way to use one database from several Python threads.

    Attributes:
        db_url: URL passed by the caller.  NOTE(review): it is stored but the
            database itself is opened as ``":memory:"`` -- confirm intent.
        max_connections: Number of handles kept idle in the pool.
        connections: Queue holding the idle handles.
        connection_lock: Guards creation of the root connection and cursors.
        hf_token: Token returned by ``get_hf_token_cached()``, used as the
            S3 credentials for the ``hf.co`` endpoint.
    """

    def __init__(self, db_url: str, max_connections: int = 32):
        """Create the pool and eagerly open ``max_connections`` handles."""
        self.db_url = db_url
        self.max_connections = max_connections
        self.connections = queue.Queue(maxsize=max_connections)
        self.connection_lock = threading.Lock()

        self.hf_token = get_hf_token_cached()

        # Root handle owning the shared database; opened on first use.
        self._root_connection = None

        self._initialize_connections()

    def _initialize_connections(self):
        """Fill the pool with ``max_connections`` ready-to-use handles."""
        for _ in range(self.max_connections):
            self.connections.put(self._create_connection())

    def _open_root(self) -> duckdb.DuckDBPyConnection:
        """Open the shared database and load the extensions it needs."""
        root = duckdb.connect(":memory:")
        try:
            root.execute("""
                INSTALL json;
                LOAD json;
                INSTALL httpfs;
                LOAD httpfs;
            """)
        except Exception:
            # Do not leak a half-configured database.
            root.close()
            raise
        return root

    def _apply_s3_settings(self, conn: duckdb.DuckDBPyConnection) -> None:
        """Point ``conn`` at the hf.co S3 endpoint and set its credentials."""
        conn.execute("""
            SET s3_endpoint='hf.co';
            SET s3_use_ssl=true;
            SET s3_url_style='path';
        """)
        if self.hf_token:
            # SET takes a literal, so escape quotes instead of interpolating
            # the raw token into the statement.
            token = str(self.hf_token).replace("'", "''")
            conn.execute(f"SET s3_access_key_id='{token}';")
            conn.execute(f"SET s3_secret_access_key='{token}';")

    def _create_connection(self) -> duckdb.DuckDBPyConnection:
        """Create a new handle onto the shared database.

        Returns:
            A configured cursor of the root connection.
        """
        # Serialise root creation and cursor duplication across threads.
        with self.connection_lock:
            if self._root_connection is None:
                self._root_connection = self._open_root()
            conn = self._root_connection.cursor()

        try:
            # Applied per handle so it works whether the settings are
            # session- or database-scoped.
            self._apply_s3_settings(conn)
        except Exception:
            conn.close()
            raise
        return conn

    def get_connection(self) -> duckdb.DuckDBPyConnection:
        """Borrow a handle, waiting up to 5 seconds for an idle one.

        If none becomes idle in time, a temporary overflow handle is created
        so the caller is never blocked indefinitely; ``return_connection``
        closes it again once the pool is full.
        """
        try:
            return self.connections.get(timeout=5)
        except queue.Empty:
            return self._create_connection()

    def return_connection(self, conn: duckdb.DuckDBPyConnection):
        """Give a handle back to the pool, closing it if the pool is full."""
        try:
            self.connections.put_nowait(conn)
        except queue.Full:
            # Overflow handle: the pool already holds max_connections.
            conn.close()

    def close(self):
        """Close every idle handle and the root connection."""
        while True:
            try:
                conn = self.connections.get_nowait()
            except queue.Empty:
                break
            conn.close()

        with self.connection_lock:
            root, self._root_connection = self._root_connection, None
        if root is not None:
            root.close()
|
|
|
|
|
|
class MultithreadStorage(TensorStorage):
    """
    Thread-aware tensor storage backed by DuckDB connections from a pool.

    Blocking database work runs on two thread pools (reads and writes) and is
    exposed through ``async`` methods.  Operations that carry the same
    ``thread_id`` are serialised by a per-id lock; operations without a
    ``thread_id`` share ``global_lock``.

    Inherits tensor operations from TensorStorage.
    """

    def __init__(self, db_url: str = DB_URL, max_connections: int = 32):
        """Create the connection pool, executors and database schema.

        Args:
            db_url: URL handed to the connection pool; its MD5 prefix is also
                used as ``storage_id`` in generated block ids.
            max_connections: Size of the connection pool.
        """
        # NOTE(review): TensorStorage.__init__ is not invoked here -- confirm
        # whether the base class needs initialisation.
        self.connection_pool = ConnectionPool(db_url, max_connections)

        # Short identifier derived from the URL (not a security use of MD5).
        self.storage_id = hashlib.md5(db_url.encode()).hexdigest()[:8]
        self.thread_locks: Dict[int, threading.Lock] = {}
        self.global_lock = threading.Lock()

        self.blocks: Dict[str, StorageBlock] = {}
        self.block_locks: Dict[str, threading.Lock] = {}

        self.read_executor = ThreadPoolExecutor(max_workers=16, thread_name_prefix="read")
        self.write_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="write")

        self.stats = {
            'total_size': 0,
            'available_size': float('inf'),
            'model_count': 0,
            'tensor_count': 0,
            'active_threads': set(),
            'thread_ops': {},
        }

        self._init_database()

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    @contextmanager
    def _connection(self):
        """Borrow a pooled connection and always hand it back."""
        conn = self.connection_pool.get_connection()
        try:
            yield conn
        finally:
            self.connection_pool.return_connection(conn)

    @staticmethod
    def _to_timestamp(epoch_seconds: float) -> datetime:
        """Convert epoch seconds to a naive UTC datetime for TIMESTAMP columns."""
        return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _parse_shape(shape_str: str) -> tuple:
        """Parse ``str(ndarray.shape)`` back into a tuple of ints.

        Handles ``"()"`` and the trailing comma of 1-D shapes such as
        ``"(5,)"`` by skipping empty parts.
        """
        parts = shape_str.strip('()').split(',')
        return tuple(int(part) for part in parts if part.strip())

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    def _init_database(self):
        """Create the tables and indexes if they do not exist yet."""
        with self._connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS storage_blocks (
                    block_id VARCHAR PRIMARY KEY,
                    size BIGINT,
                    thread_id INTEGER,
                    last_accessed TIMESTAMP,
                    is_locked BOOLEAN,
                    data BLOB
                );

                CREATE TABLE IF NOT EXISTS thread_stats (
                    thread_id INTEGER PRIMARY KEY,
                    ops_count BIGINT,
                    total_bytes BIGINT,
                    last_active TIMESTAMP
                );

                CREATE TABLE IF NOT EXISTS tensors (
                    tensor_id VARCHAR PRIMARY KEY,
                    shape VARCHAR,
                    dtype VARCHAR,
                    block_id VARCHAR,
                    thread_id INTEGER,
                    FOREIGN KEY(block_id) REFERENCES storage_blocks(block_id)
                );

                CREATE INDEX IF NOT EXISTS idx_blocks_thread
                ON storage_blocks(thread_id);

                CREATE INDEX IF NOT EXISTS idx_tensors_thread
                ON tensors(thread_id);
            """)

    # ------------------------------------------------------------------
    # Store
    # ------------------------------------------------------------------

    async def store_tensor(self, tensor_id: str, data: np.ndarray, metadata: Dict[str, Any], thread_id: Optional[int] = None) -> bool:
        """
        Store a tensor without blocking the event loop.

        The write, the per-thread locking and the statistics update all run on
        the write executor; a ``threading.Lock`` cannot be used with
        ``async with``, so the lock is taken inside the worker thread.

        Args:
            tensor_id: Primary key of the tensor row.
            data: Array whose raw bytes, shape and dtype are persisted.
            metadata: Accepted for interface compatibility; not persisted by
                this implementation.
            thread_id: Owner id used for locking and statistics, or None.

        Returns:
            True when the tensor was stored, False on any error (the error is
            logged).
        """
        block_id = f"block_{self.storage_id}_{uuid.uuid4().hex[:8]}"

        try:
            future = self.write_executor.submit(
                self._store_tensor_locked,
                block_id,
                tensor_id,
                data,
                metadata,
                thread_id,
            )
            return await asyncio.wrap_future(future)
        except Exception as e:
            logging.error("Error storing tensor %s for thread %s: %s", tensor_id, thread_id, e)
            return False

    def _store_tensor_locked(self, block_id: str, tensor_id: str, data: np.ndarray, metadata: Dict[str, Any], thread_id: Optional[int]) -> bool:
        """Worker-thread body of ``store_tensor``: lock, write, update stats."""
        with self._get_thread_lock(thread_id):
            stored = self._store_tensor_data(block_id, tensor_id, data, metadata, thread_id)
            if stored:
                self._update_thread_stats(thread_id, data.nbytes)
            return stored

    def _store_tensor_data(self, block_id: str, tensor_id: str, data: np.ndarray, metadata: Dict[str, Any], thread_id: Optional[int]) -> bool:
        """Insert the block row and the tensor row in one transaction.

        Returns:
            True on success; False (after logging) if serialisation or either
            INSERT fails, in which case nothing is left behind.
        """
        with self._connection() as conn:
            try:
                # Serialise once; the payload is reused for size and storage.
                payload = data.tobytes()
                block = StorageBlock(
                    block_id=block_id,
                    size=len(payload),
                    thread_id=thread_id,
                    last_accessed=time.time(),
                    data=data,
                )

                conn.execute("BEGIN TRANSACTION")
                try:
                    conn.execute("""
                        INSERT INTO storage_blocks (block_id, size, thread_id, last_accessed, is_locked, data)
                        VALUES (?, ?, ?, ?, ?, ?)
                    """, (
                        block.block_id,
                        block.size,
                        block.thread_id,
                        self._to_timestamp(block.last_accessed),
                        block.is_locked,
                        payload,
                    ))

                    conn.execute("""
                        INSERT INTO tensors (tensor_id, shape, dtype, block_id, thread_id)
                        VALUES (?, ?, ?, ?, ?)
                    """, (tensor_id, str(data.shape), str(data.dtype), block_id, thread_id))

                    conn.execute("COMMIT")
                except Exception:
                    # Undo the block row if the tensor row could not be written
                    # (e.g. duplicate tensor_id) so no orphan block remains.
                    try:
                        conn.execute("ROLLBACK")
                    except Exception as rollback_error:
                        logging.warning("Rollback failed for tensor %s: %s", tensor_id, rollback_error)
                    raise

                return True

            except Exception as e:
                logging.error("Database error storing tensor %s: %s", tensor_id, e)
                return False

    # ------------------------------------------------------------------
    # Retrieve
    # ------------------------------------------------------------------

    async def get_tensor(self, tensor_id: str, thread_id: Optional[int] = None) -> Optional[np.ndarray]:
        """Retrieve a tensor without blocking the event loop.

        Args:
            tensor_id: Primary key of the tensor to load.
            thread_id: Owner id whose lock serialises the read, or None.

        Returns:
            The array, or None if it does not exist or an error occurred (the
            error is logged).
        """
        try:
            future = self.read_executor.submit(
                self._get_tensor_locked,
                tensor_id,
                thread_id,
            )
            return await asyncio.wrap_future(future)
        except Exception as e:
            logging.error("Error retrieving tensor %s for thread %s: %s", tensor_id, thread_id, e)
            return None

    def _get_tensor_locked(self, tensor_id: str, thread_id: Optional[int]) -> Optional[np.ndarray]:
        """Worker-thread body of ``get_tensor``: lock, then read."""
        with self._get_thread_lock(thread_id):
            return self._get_tensor_data(tensor_id, thread_id)

    def _get_tensor_data(self, tensor_id: str, thread_id: Optional[int]) -> Optional[np.ndarray]:
        """Load a tensor's bytes, refresh its access time and rebuild the array.

        Returns:
            The array (a read-only view over the fetched bytes), or None when
            the tensor is unknown or a database/parse error occurs.
        """
        with self._connection() as conn:
            try:
                row = conn.execute("""
                    SELECT b.data, b.block_id, t.shape, t.dtype
                    FROM tensors t
                    JOIN storage_blocks b ON t.block_id = b.block_id
                    WHERE t.tensor_id = ?
                """, [tensor_id]).fetchone()

                if not row:
                    return None

                data_bytes, block_id, shape_str, dtype_str = row

                conn.execute("""
                    UPDATE storage_blocks
                    SET last_accessed = ?
                    WHERE block_id = ?
                """, (self._to_timestamp(time.time()), block_id))

                shape = self._parse_shape(shape_str)
                return np.frombuffer(data_bytes, dtype=dtype_str).reshape(shape)

            except Exception as e:
                logging.error("Database error retrieving tensor %s: %s", tensor_id, e)
                return None

    # ------------------------------------------------------------------
    # Locks and statistics
    # ------------------------------------------------------------------

    def _get_thread_lock(self, thread_id: Optional[int]) -> threading.Lock:
        """Return the lock for ``thread_id``, creating it on first use.

        ``global_lock`` doubles as the lock for callers without a thread id
        and as the guard for the ``thread_locks`` registry.
        """
        if thread_id is None:
            return self.global_lock

        with self.global_lock:
            lock = self.thread_locks.get(thread_id)
            if lock is None:
                lock = self.thread_locks[thread_id] = threading.Lock()
            return lock

    def _update_thread_stats(self, thread_id: Optional[int], bytes_processed: int):
        """Upsert the operation count and byte total for ``thread_id``.

        Does nothing when ``thread_id`` is None.
        """
        if thread_id is None:
            return

        with self._connection() as conn:
            conn.execute("""
                INSERT INTO thread_stats (thread_id, ops_count, total_bytes, last_active)
                VALUES (?, 1, ?, ?)
                ON CONFLICT(thread_id) DO UPDATE SET
                    ops_count = ops_count + 1,
                    total_bytes = total_bytes + excluded.total_bytes,
                    last_active = excluded.last_active
            """, (thread_id, bytes_processed, self._to_timestamp(time.time())))

    def get_thread_stats(self, thread_id: int) -> Dict[str, Any]:
        """Return ``ops_count``, ``total_bytes`` and ``last_active`` for a thread.

        Returns:
            A dict with those three keys, or an empty dict when the thread has
            no statistics row.
        """
        with self._connection() as conn:
            row = conn.execute("""
                SELECT ops_count, total_bytes, last_active
                FROM thread_stats
                WHERE thread_id = ?
            """, [thread_id]).fetchone()

        if not row:
            return {}

        ops_count, total_bytes, last_active = row
        return {
            'ops_count': ops_count,
            'total_bytes': total_bytes,
            'last_active': last_active,
        }

    def cleanup_thread(self, thread_id: int):
        """Release a terminated thread's lock, block ownership and statistics."""
        with self.global_lock:
            self.thread_locks.pop(thread_id, None)

        with self._connection() as conn:
            # NOTE(review): thread_id is an indexed column of a table that is
            # referenced by a foreign key; DuckDB documents limitations for
            # such updates -- verify against blocks that still have tensors.
            conn.execute("""
                UPDATE storage_blocks
                SET thread_id = NULL, is_locked = FALSE
                WHERE thread_id = ?
            """, [thread_id])

            conn.execute("""
                DELETE FROM thread_stats
                WHERE thread_id = ?
            """, [thread_id])

    # ------------------------------------------------------------------
    # Shutdown
    # ------------------------------------------------------------------

    def close(self):
        """Wait for pending work, then close every idle pooled connection."""
        self.read_executor.shutdown(wait=True)
        self.write_executor.shutdown(wait=True)

        idle = self.connection_pool.connections
        while True:
            try:
                # get_nowait avoids blocking forever if another thread drains
                # the queue between an empty() check and get().
                conn = idle.get_nowait()
            except queue.Empty:
                break
            conn.close()
|
|
|
|