|
|
"""
|
|
|
Advanced tensor storage operations extending the base storage system.
|
|
|
"""
|
|
|
|
|
|
import numpy as np
|
|
|
from typing import Dict, Any, Optional, Union, List, Tuple
|
|
|
import threading
|
|
|
from http_storage import LocalStorage
|
|
|
import logging
|
|
|
import json
|
|
|
|
|
|
class TensorOps:
    """Tensor operations for storage system.

    Converts NumPy arrays to and from a ``(bytes, metadata)`` pair. The bytes
    are always the C-order (row-major) contents of the array; the metadata
    carries ``shape`` and ``dtype`` (needed to rebuild the array) plus the
    ``strides`` of that C-order serialized layout.
    """

    @staticmethod
    def serialize_tensor(tensor: np.ndarray) -> Tuple[bytes, Dict[str, Any]]:
        """Serialize tensor to bytes with metadata.

        Args:
            tensor: Array to serialize. Array-likes (e.g. nested lists) are
                converted with ``np.asarray``.

        Returns:
            ``(data, metadata)`` where ``data`` is the C-order byte dump of the
            array and ``metadata`` has the keys ``shape`` (tuple of ints),
            ``dtype`` (``str(dtype)``) and ``strides`` (strides of the
            C-contiguous layout that ``data`` encodes).

        Raises:
            TypeError: If the dtype holds Python objects. Their raw bytes are
                memory addresses, so ``deserialize_tensor`` could never
                rebuild them.
        """
        tensor = np.asarray(tensor)
        if tensor.dtype.hasobject:
            raise TypeError(
                f"cannot serialize tensor with object dtype {tensor.dtype}"
            )

        # tobytes() always emits C order, so describe that layout rather than
        # the input's, which may be a transposed or sliced (strided) view.
        # A 0-d array is always C-contiguous, so ascontiguousarray (which
        # returns at least 1-d) is only applied to arrays with ndim >= 1.
        if not tensor.flags.c_contiguous:
            tensor = np.ascontiguousarray(tensor)

        metadata = {
            'shape': tensor.shape,
            'dtype': str(tensor.dtype),
            'strides': tensor.strides,
        }
        return tensor.tobytes(), metadata

    @staticmethod
    def deserialize_tensor(data: bytes, metadata: Dict[str, Any]) -> np.ndarray:
        """Deserialize tensor from bytes and metadata.

        Args:
            data: C-order byte dump as produced by ``serialize_tensor``. Any
                object supporting the buffer protocol is accepted.
            metadata: Mapping with at least ``shape`` (a sequence of ints; a
                JSON list works) and ``dtype`` (a string accepted by
                ``np.dtype``). Other keys are ignored.

        Returns:
            A new, writable array that does not share memory with ``data``.

        Raises:
            KeyError: If ``shape`` or ``dtype`` is missing from ``metadata``.
            ValueError: If the buffer size does not match the shape and dtype.
        """
        flat = np.frombuffer(data, dtype=np.dtype(metadata['dtype']))
        # np.frombuffer returns a read-only view over `data`; copy so callers
        # can modify the result in place and `data` can be released.
        return flat.reshape(metadata['shape']).copy()
|
|
|
|
|
|
class TensorStorage(LocalStorage):
    """
    Enhanced storage implementation with tensor operations.

    Extends LocalStorage with advanced tensor manipulation capabilities.

    Tensors are written through the inherited ``_store_in_db`` helper into the
    ``tensors`` table and read back through ``self.conn``; the ``_lock`` and
    ``stats`` attributes used here are likewise expected to come from
    ``LocalStorage``. The operation methods (``matmul``, ``add``, ...) load
    their inputs by ID, compute with NumPy, store the result under
    ``result_id`` and return that ID, or ``None`` if an input is missing or
    any step fails.
    """

    # Keys produced by TensorOps.serialize_tensor. get_tensor needs 'shape'
    # and 'dtype' to rebuild the array, so caller metadata must not override
    # them.
    _RESERVED_META_KEYS = frozenset({'shape', 'dtype', 'strides'})

    def __init__(self, db_url: Optional[str] = None):
        super().__init__(db_url)
        self._ops = TensorOps()

    def store_tensor(self, tensor_id: str, tensor: np.ndarray, metadata: Optional[Dict] = None) -> bool:
        """Store a tensor with its metadata.

        Args:
            tensor_id: ID under which the tensor is stored.
            tensor: Array to persist.
            metadata: Optional extra metadata stored alongside the
                serialization metadata. The keys ``shape``, ``dtype`` and
                ``strides`` are reserved; caller values for them are ignored
                (with a warning) so the stored description always matches the
                stored bytes.

        Returns:
            The result of ``_store_in_db`` on success, or ``False`` if any
            step raises (the error is logged).
        """
        try:
            data, tensor_meta = self._ops.serialize_tensor(tensor)

            if metadata:
                overridden = self._RESERVED_META_KEYS.intersection(metadata)
                if overridden:
                    logging.warning(
                        f"Ignoring reserved metadata keys {sorted(overridden)} "
                        f"for tensor {tensor_id}"
                    )
                tensor_meta.update(
                    (key, value)
                    for key, value in metadata.items()
                    if key not in self._RESERVED_META_KEYS
                )

            success = self._store_in_db('tensors', tensor_id, data, tensor_meta)

            if success:
                # NOTE(review): if _store_in_db replaces an existing row for
                # tensor_id, these counters over-count — confirm against
                # LocalStorage.
                with self._lock:
                    self.stats['tensor_count'] += 1
                    self.stats['total_size'] += len(data)

            return success

        except Exception as e:
            logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def get_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Retrieve a tensor by ID.

        Returns:
            The reconstructed array, or ``None`` if no row matches
            ``tensor_id`` or any step raises (the error is logged).
        """
        try:
            row = self.conn.execute("""
                SELECT data, metadata
                FROM tensors
                WHERE id = ?
            """, [tensor_id]).fetchone()

            if not row:
                return None

            data, metadata = row
            # Metadata may come back as JSON (text or bytes) or as an
            # already-decoded mapping; only decode the former.
            if isinstance(metadata, (str, bytes, bytearray)):
                metadata = json.loads(metadata)

            return self._ops.deserialize_tensor(data, metadata)

        except Exception as e:
            logging.error(f"Error retrieving tensor {tensor_id}: {str(e)}")
            return None

    def _apply_and_store(self, tensor_ids: List[str], result_id: str, op, error_label: str) -> Optional[str]:
        """Load ``tensor_ids``, store ``op(*tensors)`` as ``result_id``.

        Args:
            tensor_ids: IDs of the input tensors, passed to ``op`` in order.
            result_id: ID under which the result is stored.
            op: Callable taking the loaded arrays and returning an array.
            error_label: Operation name used in the log message
                ``"Error in <error_label>: <exception>"``.

        Returns:
            ``result_id`` on success; ``None`` if any input is missing, ``op``
            raises (logged), or the result could not be stored.
        """
        # Every input is fetched before checking for missing ones.
        tensors = [self.get_tensor(tid) for tid in tensor_ids]

        # Identity test on purpose: `None in tensors` compares each array with
        # `==`, which for multi-element arrays yields an array whose truth
        # value is ambiguous and raises ValueError.
        if any(t is None for t in tensors):
            return None

        try:
            result = op(*tensors)
            if self.store_tensor(result_id, result):
                return result_id
            return None

        except Exception as e:
            logging.error(f"Error in {error_label}: {str(e)}")
            return None

    def matmul(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Perform matrix multiplication between two stored tensors.

        Stores ``np.matmul(a, b)`` under ``result_id``.

        Returns:
            ``result_id`` on success, ``None`` otherwise.
        """
        return self._apply_and_store(
            [tensor_id_a, tensor_id_b], result_id, np.matmul,
            'matrix multiplication',
        )

    def add(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Add two stored tensors.

        Stores ``a + b`` (NumPy broadcasting applies) under ``result_id``.

        Returns:
            ``result_id`` on success, ``None`` otherwise.
        """
        return self._apply_and_store(
            [tensor_id_a, tensor_id_b], result_id, lambda a, b: a + b,
            'tensor addition',
        )

    def multiply(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Element-wise multiply two stored tensors.

        Stores ``a * b`` (NumPy broadcasting applies) under ``result_id``.

        Returns:
            ``result_id`` on success, ``None`` otherwise.
        """
        return self._apply_and_store(
            [tensor_id_a, tensor_id_b], result_id, lambda a, b: a * b,
            'tensor multiplication',
        )

    def transpose(self, tensor_id: str, result_id: str, axes: Optional[Tuple[int, ...]] = None) -> Optional[str]:
        """Transpose a stored tensor.

        Stores ``np.transpose(tensor, axes)`` under ``result_id``. With the
        default ``axes=None`` the axis order is reversed.

        Returns:
            ``result_id`` on success, ``None`` otherwise.
        """
        return self._apply_and_store(
            [tensor_id], result_id, lambda t: np.transpose(t, axes),
            'tensor transpose',
        )

    def reshape(self, tensor_id: str, new_shape: Tuple[int, ...], result_id: str) -> Optional[str]:
        """Reshape a stored tensor.

        Stores ``tensor.reshape(new_shape)`` under ``result_id``.

        Returns:
            ``result_id`` on success, ``None`` otherwise (including when
            ``new_shape`` is incompatible with the tensor's size).
        """
        return self._apply_and_store(
            [tensor_id], result_id, lambda t: t.reshape(new_shape),
            'tensor reshape',
        )

    def split(self, tensor_id: str, indices_or_sections: Union[int, List[int]], axis: int = 0,
              *, result_prefix: Optional[str] = None) -> List[str]:
        """Split a stored tensor into multiple tensors.

        Each piece of ``np.split(tensor, indices_or_sections, axis=axis)`` is
        stored as ``f"{prefix}_split_{i}"``, where ``prefix`` is
        ``result_prefix`` if given and ``tensor_id`` otherwise.

        Returns:
            IDs of the pieces that were stored, in split order. A piece that
            fails to store is logged and omitted, so the list can be shorter
            than the number of pieces. Empty if the source tensor is missing or
            the split itself fails (e.g. sections do not divide the axis).
        """
        tensor = self.get_tensor(tensor_id)
        if tensor is None:
            return []

        prefix = tensor_id if result_prefix is None else result_prefix

        try:
            pieces = np.split(tensor, indices_or_sections, axis=axis)

            result_ids = []
            for i, piece in enumerate(pieces):
                result_id = f"{prefix}_split_{i}"
                if self.store_tensor(result_id, piece):
                    result_ids.append(result_id)
                else:
                    logging.warning(
                        f"Failed to store split {i} of tensor {tensor_id} as {result_id}"
                    )

            return result_ids

        except Exception as e:
            logging.error(f"Error in tensor split: {str(e)}")
            return []

    def concatenate(self, tensor_ids: List[str], result_id: str, axis: int = 0) -> Optional[str]:
        """Concatenate multiple stored tensors.

        Stores ``np.concatenate(tensors, axis=axis)`` under ``result_id``.

        Returns:
            ``result_id`` on success; ``None`` if any input is missing,
            ``tensor_ids`` is empty, shapes are incompatible, or storing fails.
        """
        return self._apply_and_store(
            tensor_ids, result_id,
            lambda *tensors: np.concatenate(tensors, axis=axis),
            'tensor concatenation',
        )
|
|
|
|