""" Advanced tensor storage operations extending the base storage system. """ import numpy as np from typing import Dict, Any, Optional, Union, List, Tuple import threading from http_storage import LocalStorage import logging import json class TensorOps: """Tensor operations for storage system""" @staticmethod def serialize_tensor(tensor: np.ndarray) -> Tuple[bytes, Dict[str, Any]]: """Serialize tensor to bytes with metadata""" metadata = { 'shape': tensor.shape, 'dtype': str(tensor.dtype), 'strides': tensor.strides } return tensor.tobytes(), metadata @staticmethod def deserialize_tensor(data: bytes, metadata: Dict[str, Any]) -> np.ndarray: """Deserialize tensor from bytes and metadata""" tensor = np.frombuffer(data, dtype=np.dtype(metadata['dtype'])) return tensor.reshape(metadata['shape']) class TensorStorage(LocalStorage): """ Enhanced storage implementation with tensor operations. Extends LocalStorage with advanced tensor manipulation capabilities. """ def __init__(self, db_url: str = None): super().__init__(db_url) self._ops = TensorOps() def store_tensor(self, tensor_id: str, tensor: np.ndarray, metadata: Optional[Dict] = None) -> bool: """Store a tensor with its metadata""" try: # Serialize tensor data, tensor_meta = self._ops.serialize_tensor(tensor) # Merge with additional metadata if metadata: tensor_meta.update(metadata) # Store in database success = self._store_in_db('tensors', tensor_id, data, tensor_meta) if success: # Update stats with self._lock: self.stats['tensor_count'] += 1 self.stats['total_size'] += len(data) return success except Exception as e: logging.error(f"Error storing tensor {tensor_id}: {str(e)}") return False def get_tensor(self, tensor_id: str) -> Optional[np.ndarray]: """Retrieve a tensor by ID""" try: # Get from database result = self.conn.execute(""" SELECT data, metadata FROM tensors WHERE id = ? """, [tensor_id]).fetchone() if not result: return None data, metadata = result metadata = json.loads(metadata) if isinstance(metadata, str) else metadata # Deserialize tensor return self._ops.deserialize_tensor(data, metadata) except Exception as e: logging.error(f"Error retrieving tensor {tensor_id}: {str(e)}") return None def matmul(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]: """Perform matrix multiplication between two stored tensors""" tensor_a = self.get_tensor(tensor_id_a) tensor_b = self.get_tensor(tensor_id_b) if tensor_a is None or tensor_b is None: return None try: result = np.matmul(tensor_a, tensor_b) if self.store_tensor(result_id, result): return result_id return None except Exception as e: logging.error(f"Error in matrix multiplication: {str(e)}") return None def add(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]: """Add two stored tensors""" tensor_a = self.get_tensor(tensor_id_a) tensor_b = self.get_tensor(tensor_id_b) if tensor_a is None or tensor_b is None: return None try: result = tensor_a + tensor_b if self.store_tensor(result_id, result): return result_id return None except Exception as e: logging.error(f"Error in tensor addition: {str(e)}") return None def multiply(self, tensor_id_a: str, tensor_id_b: str, result_id: str) -> Optional[str]: """Element-wise multiply two stored tensors""" tensor_a = self.get_tensor(tensor_id_a) tensor_b = self.get_tensor(tensor_id_b) if tensor_a is None or tensor_b is None: return None try: result = tensor_a * tensor_b if self.store_tensor(result_id, result): return result_id return None except Exception as e: logging.error(f"Error in tensor multiplication: {str(e)}") return None def transpose(self, tensor_id: str, result_id: str) -> Optional[str]: """Transpose a stored tensor""" tensor = self.get_tensor(tensor_id) if tensor is None: return None try: result = np.transpose(tensor) if self.store_tensor(result_id, result): return result_id return None except Exception as e: logging.error(f"Error in tensor transpose: {str(e)}") return None def reshape(self, tensor_id: str, new_shape: Tuple[int, ...], result_id: str) -> Optional[str]: """Reshape a stored tensor""" tensor = self.get_tensor(tensor_id) if tensor is None: return None try: result = tensor.reshape(new_shape) if self.store_tensor(result_id, result): return result_id return None except Exception as e: logging.error(f"Error in tensor reshape: {str(e)}") return None def split(self, tensor_id: str, indices_or_sections: Union[int, List[int]], axis: int = 0) -> List[str]: """Split a stored tensor into multiple tensors""" tensor = self.get_tensor(tensor_id) if tensor is None: return [] try: # Split the tensor split_results = np.split(tensor, indices_or_sections, axis=axis) # Store each split result result_ids = [] for i, split_tensor in enumerate(split_results): result_id = f"{tensor_id}_split_{i}" if self.store_tensor(result_id, split_tensor): result_ids.append(result_id) return result_ids except Exception as e: logging.error(f"Error in tensor split: {str(e)}") return [] def concatenate(self, tensor_ids: List[str], result_id: str, axis: int = 0) -> Optional[str]: """Concatenate multiple stored tensors""" tensors = [self.get_tensor(tid) for tid in tensor_ids] if None in tensors: return None try: result = np.concatenate(tensors, axis=axis) if self.store_tensor(result_id, result): return result_id return None except Exception as e: logging.error(f"Error in tensor concatenation: {str(e)}") return None