# tensor_storage.py
"""
Advanced tensor storage operations extending the base storage system.
"""
import numpy as np
from typing import Dict, Any, Optional, Union, List, Tuple
import threading
from http_storage import LocalStorage
import logging
import json
class TensorOps:
    """Stateless helpers for converting tensors to and from raw bytes."""

    @staticmethod
    def serialize_tensor(tensor: np.ndarray) -> Tuple[bytes, Dict[str, Any]]:
        """Serialize a tensor to raw bytes plus reconstruction metadata.

        Args:
            tensor: Array to serialize. May be any layout (views and
                transposed arrays included).

        Returns:
            A ``(data, metadata)`` pair where ``data`` is a C-contiguous
            byte buffer and ``metadata`` holds ``shape``, ``dtype`` and
            ``strides`` describing that buffer.
        """
        # tobytes() always emits a C-contiguous buffer; make the array
        # contiguous first so the recorded strides match the serialized
        # layout (the original strides would be wrong for views).
        contiguous = np.ascontiguousarray(tensor)
        metadata = {
            'shape': contiguous.shape,
            'dtype': str(contiguous.dtype),
            'strides': contiguous.strides,
        }
        return contiguous.tobytes(), metadata

    @staticmethod
    def deserialize_tensor(data: bytes, metadata: Dict[str, Any]) -> np.ndarray:
        """Reconstruct a tensor from bytes and metadata.

        Args:
            data: Byte buffer produced by :meth:`serialize_tensor`.
            metadata: Dict with at least ``dtype`` and ``shape``; ``shape``
                may be a list (e.g. after a JSON round-trip) or a tuple.

        Returns:
            A writable array with the recorded shape and dtype.
        """
        flat = np.frombuffer(data, dtype=np.dtype(metadata['dtype']))
        # frombuffer yields a read-only view over `data`; copy so callers
        # receive a normal, writable array.
        return flat.reshape(tuple(metadata['shape'])).copy()
class TensorStorage(LocalStorage):
    """
    Enhanced storage implementation with tensor operations.

    Extends LocalStorage with tensor persistence (store/retrieve) and a set
    of arithmetic/shape operations (matmul, add, multiply, transpose,
    reshape, split, concatenate) whose results are stored back under
    caller-supplied IDs.

    All operation methods follow the same contract: they return the result
    tensor ID on success and ``None`` (or ``[]`` for :meth:`split`) on any
    failure, logging the error instead of raising.
    """

    def __init__(self, db_url: str = None):
        """Initialize storage.

        Args:
            db_url: Connection string forwarded to LocalStorage;
                ``None`` uses the parent's default.
        """
        super().__init__(db_url)
        self._ops = TensorOps()

    def store_tensor(self, tensor_id: str, tensor: np.ndarray, metadata: Optional[Dict] = None) -> bool:
        """Store a tensor with its metadata.

        Args:
            tensor_id: Key under which the tensor is stored.
            tensor: Array to persist.
            metadata: Optional extra metadata merged over the serialization
                metadata (shape/dtype/strides). NOTE: keys here can shadow
                the serialization keys — callers should avoid 'shape',
                'dtype' and 'strides'.

        Returns:
            True on success, False on any failure.
        """
        try:
            data, tensor_meta = self._ops.serialize_tensor(tensor)
            if metadata:
                tensor_meta.update(metadata)
            # _store_in_db and self._lock/self.stats are provided by
            # LocalStorage (http_storage module).
            success = self._store_in_db('tensors', tensor_id, data, tensor_meta)
            if success:
                with self._lock:
                    self.stats['tensor_count'] += 1
                    self.stats['total_size'] += len(data)
            return success
        except Exception as e:
            logging.error(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def get_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Retrieve a tensor by ID.

        Returns:
            The deserialized array, or ``None`` if the ID is unknown or
            retrieval/deserialization fails.
        """
        try:
            result = self.conn.execute("""
                SELECT data, metadata
                FROM tensors
                WHERE id = ?
            """, [tensor_id]).fetchone()
            if not result:
                return None
            data, metadata = result
            # Metadata may come back as a JSON string or an already-decoded
            # mapping depending on the backing driver.
            metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
            return self._ops.deserialize_tensor(data, metadata)
        except Exception as e:
            logging.error(f"Error retrieving tensor {tensor_id}: {str(e)}")
            return None

    def matmul(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Matrix-multiply two stored tensors and store the result.

        Returns:
            ``result_id`` on success, ``None`` if either input is missing,
            shapes are incompatible, or storing fails.
        """
        tensor_a = self.get_tensor(tensor_id_a)
        tensor_b = self.get_tensor(tensor_id_b)
        if tensor_a is None or tensor_b is None:
            return None
        try:
            result = np.matmul(tensor_a, tensor_b)
            if self.store_tensor(result_id, result):
                return result_id
            return None
        except Exception as e:
            logging.error(f"Error in matrix multiplication: {str(e)}")
            return None

    def add(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Element-wise add two stored tensors (NumPy broadcasting applies).

        Returns:
            ``result_id`` on success, otherwise ``None``.
        """
        tensor_a = self.get_tensor(tensor_id_a)
        tensor_b = self.get_tensor(tensor_id_b)
        if tensor_a is None or tensor_b is None:
            return None
        try:
            result = tensor_a + tensor_b
            if self.store_tensor(result_id, result):
                return result_id
            return None
        except Exception as e:
            logging.error(f"Error in tensor addition: {str(e)}")
            return None

    def multiply(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]:
        """Element-wise multiply two stored tensors (broadcasting applies).

        Returns:
            ``result_id`` on success, otherwise ``None``.
        """
        tensor_a = self.get_tensor(tensor_id_a)
        tensor_b = self.get_tensor(tensor_id_b)
        if tensor_a is None or tensor_b is None:
            return None
        try:
            result = tensor_a * tensor_b
            if self.store_tensor(result_id, result):
                return result_id
            return None
        except Exception as e:
            logging.error(f"Error in tensor multiplication: {str(e)}")
            return None

    def transpose(self, tensor_id: str, result_id: str) -> Optional[str]:
        """Transpose a stored tensor (axes fully reversed) and store it.

        Returns:
            ``result_id`` on success, otherwise ``None``.
        """
        tensor = self.get_tensor(tensor_id)
        if tensor is None:
            return None
        try:
            result = np.transpose(tensor)
            if self.store_tensor(result_id, result):
                return result_id
            return None
        except Exception as e:
            logging.error(f"Error in tensor transpose: {str(e)}")
            return None

    def reshape(self, tensor_id: str, new_shape: Tuple[int, ...], result_id: str) -> Optional[str]:
        """Reshape a stored tensor and store the result.

        Returns:
            ``result_id`` on success, ``None`` if the element count does not
            match ``new_shape`` or any step fails.
        """
        tensor = self.get_tensor(tensor_id)
        if tensor is None:
            return None
        try:
            result = tensor.reshape(new_shape)
            if self.store_tensor(result_id, result):
                return result_id
            return None
        except Exception as e:
            logging.error(f"Error in tensor reshape: {str(e)}")
            return None

    def split(self, tensor_id: str, indices_or_sections: Union[int, List[int]], axis: int = 0) -> List[str]:
        """Split a stored tensor and store each piece.

        Pieces are stored under ``"{tensor_id}_split_{i}"``.

        Returns:
            IDs of the pieces that were stored successfully (a piece that
            fails to store is silently skipped); ``[]`` if the source tensor
            is missing or the split itself fails.
        """
        tensor = self.get_tensor(tensor_id)
        if tensor is None:
            return []
        try:
            split_results = np.split(tensor, indices_or_sections, axis=axis)
            result_ids = []
            for i, split_tensor in enumerate(split_results):
                result_id = f"{tensor_id}_split_{i}"
                if self.store_tensor(result_id, split_tensor):
                    result_ids.append(result_id)
            return result_ids
        except Exception as e:
            logging.error(f"Error in tensor split: {str(e)}")
            return []

    def concatenate(self, tensor_ids: List[str], result_id: str, axis: int = 0) -> Optional[str]:
        """Concatenate multiple stored tensors along ``axis`` and store it.

        Returns:
            ``result_id`` on success, ``None`` if any input is missing or
            concatenation/storing fails.
        """
        tensors = [self.get_tensor(tid) for tid in tensor_ids]
        # BUG FIX: `None in tensors` compares None to each ndarray with
        # `==`, producing an element-wise boolean array whose truth value
        # is ambiguous — it raised ValueError for any multi-element tensor.
        # Identity checks are the correct None test for arrays.
        if any(t is None for t in tensors):
            return None
        try:
            result = np.concatenate(tensors, axis=axis)
            if self.store_tensor(result_id, result):
                return result_id
            return None
        except Exception as e:
            logging.error(f"Error in tensor concatenation: {str(e)}")
            return None