# Source listing header (extraction residue): 15,275 bytes, commit aa654a4
import torch
from typing import List, Dict, Callable, Optional, Any
import logging
import time
import uuid
import random # Added for sampling
# Configure the root logger once at import time: INFO level, timestamped lines.
# NOTE(review): calling basicConfig at module import is a global side effect;
# library-style code usually prefers `logger = logging.getLogger(__name__)` —
# confirm no consumer relies on this root configuration before changing it.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class TensorStorage:
    """
    Manages datasets stored as collections of tensors in memory.

    Internal layout:
        { dataset_name: {"tensors": List[torch.Tensor], "metadata": List[Dict]} }

    The "tensors" and "metadata" lists are kept index-aligned: entry i of one
    corresponds to entry i of the other. This is an in-memory prototype; a
    production system would need persistent storage, chunking of large tensors,
    and metadata indexing (see the placeholder notes in the methods below).
    """

    # Metadata keys generated by insert(); user-supplied values for these keys
    # are always overridden so stored records stay internally consistent.
    _RESERVED_KEYS = ("record_id", "timestamp_utc", "shape", "dtype", "version")

    def __init__(self) -> None:
        """Initializes the TensorStorage with an empty dataset registry."""
        # In-memory storage. Replace with a persistent storage solution for production.
        self.datasets: Dict[str, Dict[str, List[Any]]] = {}
        logging.info("TensorStorage initialized (In-Memory).")

    def create_dataset(self, name: str) -> None:
        """
        Creates a new, empty dataset.

        Args:
            name: The unique name for the new dataset.

        Raises:
            ValueError: If a dataset with the same name already exists.
        """
        if name in self.datasets:
            logging.warning(f"Attempted to create dataset '{name}' which already exists.")
            raise ValueError(f"Dataset '{name}' already exists.")
        self.datasets[name] = {"tensors": [], "metadata": []}
        logging.info(f"Dataset '{name}' created successfully.")

    def insert(self, name: str, tensor: torch.Tensor, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Inserts a tensor into a specified dataset.

        The tensor is stored as a clone, so later mutation of the caller's
        tensor does not affect the stored copy.

        Args:
            name: The name of the dataset to insert into.
            tensor: The PyTorch tensor to insert.
            metadata: Optional dictionary of user metadata (e.g. source,
                custom tags). Reserved keys (record_id, timestamp_utc, shape,
                dtype, version) are generated and always take precedence.

        Returns:
            str: The unique ID assigned to the inserted record. This ID is
                guaranteed to match the stored metadata's "record_id".

        Raises:
            ValueError: If the dataset does not exist.
            TypeError: If the provided object is not a PyTorch tensor.
        """
        if name not in self.datasets:
            logging.error(f"Dataset '{name}' not found for insertion.")
            raise ValueError(f"Dataset '{name}' does not exist. Create it first.")
        if not isinstance(tensor, torch.Tensor):
            logging.error(f"Attempted to insert non-tensor data into dataset '{name}'.")
            raise TypeError("Data to be inserted must be a torch.Tensor.")

        if metadata is None:
            metadata = {}

        record_id = str(uuid.uuid4())
        default_metadata = {
            "record_id": record_id,
            "timestamp_utc": time.time(),
            "shape": tuple(tensor.shape),
            "dtype": str(tensor.dtype),
            # Placeholder for versioning - simple sequence for now.
            "version": len(self.datasets[name]["tensors"]) + 1,
        }

        # Surface collisions so callers learn their key was overridden.
        for key in self._RESERVED_KEYS:
            if key in metadata:
                logging.warning(f"Provided metadata key '{key}' might conflict with generated defaults.")

        # BUG FIX: the previous implementation applied the user metadata a
        # second time after merging, letting it overwrite the generated
        # essential fields. In particular a user-supplied "record_id" made the
        # stored ID differ from the returned one, breaking get_tensor_by_id.
        # Now: user keys win for non-reserved fields, generated defaults
        # always win for reserved fields.
        final_metadata = {**metadata, **default_metadata}

        # --- Placeholder for Chunking Logic ---
        # In a real implementation, large tensors would be chunked here and
        # each chunk stored separately with associated metadata.
        # --------------------------------------
        self.datasets[name]["tensors"].append(tensor.clone())  # Store a copy
        self.datasets[name]["metadata"].append(final_metadata)
        logging.debug(f"Tensor with shape {tuple(tensor.shape)} inserted into dataset '{name}'. Record ID: {record_id}")
        return record_id

    def get_dataset(self, name: str) -> List[torch.Tensor]:
        """
        Retrieves all tensors from a specified dataset.

        Args:
            name: The name of the dataset to retrieve.

        Returns:
            The list of all tensors in the dataset. NOTE: this is the internal
            list (not a copy); callers must not mutate it.

        Raises:
            ValueError: If the dataset does not exist.
        """
        if name not in self.datasets:
            logging.error(f"Dataset '{name}' not found for retrieval.")
            raise ValueError(f"Dataset '{name}' does not exist.")
        logging.debug(f"Retrieving all {len(self.datasets[name]['tensors'])} tensors from dataset '{name}'.")
        # --- Placeholder for Reassembling Chunks ---
        # If data was chunked, it would be reassembled here before returning.
        # -------------------------------------------
        return self.datasets[name]["tensors"]

    def get_dataset_with_metadata(self, name: str) -> List[Dict[str, Any]]:
        """
        Retrieves all tensors and their metadata from a specified dataset.

        Args:
            name: The name of the dataset to retrieve.

        Returns:
            A list of dictionaries, each containing a 'tensor' and its 'metadata'.

        Raises:
            ValueError: If the dataset does not exist.
        """
        if name not in self.datasets:
            logging.error(f"Dataset '{name}' not found for retrieval with metadata.")
            raise ValueError(f"Dataset '{name}' does not exist.")
        logging.debug(f"Retrieving all {len(self.datasets[name]['tensors'])} tensors and metadata from dataset '{name}'.")
        return [
            {"tensor": tensor, "metadata": meta}
            for tensor, meta in zip(self.datasets[name]["tensors"], self.datasets[name]["metadata"])
        ]

    def query(self, name: str, query_fn: Callable[[torch.Tensor, Dict[str, Any]], bool]) -> List[Dict[str, Any]]:
        """
        Queries a dataset using a function that filters tensors based on the
        tensor data itself and/or its metadata.

        Records for which query_fn raises are skipped with a warning rather
        than aborting the whole query (best-effort semantics).

        Args:
            name: The name of the dataset to query.
            query_fn: A callable taking (tensor, metadata) and returning True
                if the record should be included in the result.

        Returns:
            A list of {'tensor': ..., 'metadata': ...} dicts satisfying query_fn.

        Raises:
            ValueError: If the dataset does not exist.
            TypeError: If query_fn is not callable.
        """
        if name not in self.datasets:
            logging.error(f"Dataset '{name}' not found for querying.")
            raise ValueError(f"Dataset '{name}' does not exist.")
        if not callable(query_fn):
            logging.error(f"Provided query_fn is not callable for dataset '{name}'.")
            raise TypeError("query_fn must be a callable function.")
        logging.debug(f"Querying dataset '{name}' with custom function.")
        results = []
        # --- Placeholder for Optimized Querying ---
        # In a real system, metadata indexing would speed this up significantly.
        # ------------------------------------------
        for tensor, meta in zip(self.datasets[name]["tensors"], self.datasets[name]["metadata"]):
            try:
                if query_fn(tensor, meta):
                    results.append({"tensor": tensor, "metadata": meta})
            except Exception as e:
                logging.warning(f"Error executing query_fn on tensor {meta.get('record_id', 'N/A')} in dataset '{name}': {e}")
                continue
        logging.info(f"Query on dataset '{name}' returned {len(results)} results.")
        return results

    def get_tensor_by_id(self, name: str, record_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a specific tensor and its metadata by its unique record ID.

        Args:
            name: The name of the dataset.
            record_id: The unique ID of the record to retrieve.

        Returns:
            A dictionary containing the 'tensor' and 'metadata', or None if not found.

        Raises:
            ValueError: If the dataset does not exist.
        """
        if name not in self.datasets:
            logging.error(f"Dataset '{name}' not found for get_tensor_by_id.")
            raise ValueError(f"Dataset '{name}' does not exist.")
        # Linear scan: inefficient for large datasets; a real system needs an index.
        for tensor, meta in zip(self.datasets[name]["tensors"], self.datasets[name]["metadata"]):
            if meta.get("record_id") == record_id:
                logging.debug(f"Tensor with record_id '{record_id}' found in dataset '{name}'.")
                return {"tensor": tensor, "metadata": meta}
        logging.warning(f"Tensor with record_id '{record_id}' not found in dataset '{name}'.")
        return None

    def sample_dataset(self, name: str, n_samples: int) -> List[Dict[str, Any]]:
        """
        Retrieves a random sample of records (tensor and metadata) from a dataset.

        Args:
            name: The name of the dataset to sample from.
            n_samples: The number of samples to retrieve. Non-positive values
                yield an empty list.

        Returns:
            A list of {'tensor': ..., 'metadata': ...} dicts for the sampled
            records. If n_samples >= dataset size, all records are returned
            in shuffled order.

        Raises:
            ValueError: If the dataset does not exist.
        """
        if name not in self.datasets:
            logging.error(f"Dataset '{name}' not found for sampling.")
            raise ValueError(f"Dataset '{name}' does not exist.")
        dataset_size = len(self.datasets[name]["tensors"])
        if n_samples <= 0:
            return []
        if n_samples >= dataset_size:
            logging.warning(f"Requested {n_samples} samples from dataset '{name}' which only has {dataset_size} items. Returning all items shuffled.")
            indices = list(range(dataset_size))
            random.shuffle(indices)
        else:
            # random.sample draws without replacement.
            indices = random.sample(range(dataset_size), n_samples)
        logging.debug(f"Sampling {len(indices)} records from dataset '{name}'.")
        # In-memory sampling is easy; persistent storage would need index lookups.
        return [
            {
                "tensor": self.datasets[name]["tensors"][i],
                "metadata": self.datasets[name]["metadata"][i],
            }
            for i in indices
        ]

    def delete_dataset(self, name: str) -> bool:
        """
        Deletes an entire dataset. Use with caution!

        Args:
            name: The name of the dataset to delete.

        Returns:
            True if the dataset was deleted, False if it didn't exist.
        """
        if name in self.datasets:
            del self.datasets[name]
            logging.warning(f"Dataset '{name}' has been permanently deleted.")
            return True
        else:
            logging.warning(f"Attempted to delete non-existent dataset '{name}'.")
            return False
# Example usage: exercises the full TensorStorage API end to end.
if __name__ == "__main__":
    store = TensorStorage()

    # Set up two example datasets.
    store.create_dataset("images")
    store.create_dataset("sensor_readings")

    # Build example tensors.
    image = torch.rand(3, 64, 64)  # (channels, height, width)
    readings = [
        torch.tensor([10.5, 11.2, 10.9]),
        torch.tensor([11.1, 11.5, 11.3]),
        torch.tensor([9.8, 10.1, 9.9]),
    ]
    reading_meta = [
        {"sensor_id": "XYZ", "location": "lab1"},
        {"sensor_id": "XYZ", "location": "lab1"},
        {"sensor_id": "ABC", "location": "lab2"},
    ]

    # Insert records: one image, then the three sensor readings in order.
    image_id = store.insert("images", image, metadata={"source": "camera_A", "label": "cat"})
    reading_ids = [
        store.insert("sensor_readings", t, metadata=m)
        for t, m in zip(readings, reading_meta)
    ]
    print(f"Inserted image with ID: {image_id}")
    for pos, rid in enumerate(reading_ids, start=1):
        print(f"Inserted sensor reading {pos} with ID: {rid}")

    # Fetch every record (tensor + metadata) from a dataset.
    records = store.get_dataset_with_metadata("sensor_readings")
    print(f"\nRetrieved {len(records)} sensor records:")
    for entry in records:
        print(f" Metadata: {entry['metadata']}, Tensor shape: {entry['tensor'].shape}")

    # Query by tensor contents.
    print("\nQuerying sensor readings with first value > 11.0:")
    for entry in store.query("sensor_readings", lambda tensor, meta: tensor[0].item() > 11.0):
        print(f" Metadata: {entry['metadata']}, Tensor: {entry['tensor']}")

    # Query by metadata.
    print("\nQuerying sensor readings from sensor 'XYZ':")
    for entry in store.query("sensor_readings", lambda tensor, meta: meta.get("sensor_id") == "XYZ"):
        print(f" Metadata: {entry['metadata']}, Tensor: {entry['tensor']}")

    # Direct lookup by record ID.
    first_id = reading_ids[0]
    print(f"\nRetrieving sensor reading with ID {first_id}:")
    found = store.get_tensor_by_id("sensor_readings", first_id)
    if found:
        print(f" Metadata: {found['metadata']}, Tensor: {found['tensor']}")

    # Random sampling, both under and over the dataset size.
    print(f"\nSampling 2 records from sensor_readings:")
    pair = store.sample_dataset("sensor_readings", 2)
    print(f" Got {len(pair)} samples.")
    for pos, entry in enumerate(pair, start=1):
        print(f" Sample {pos} - Record ID: {entry['metadata'].get('record_id')}, Tensor shape: {entry['tensor'].shape}")

    print(f"\nSampling 5 records (more than available):")
    everything = store.sample_dataset("sensor_readings", 5)
    print(f" Got {len(everything)} samples.")
    for pos, entry in enumerate(everything, start=1):
        print(f" Sample {pos} - Record ID: {entry['metadata'].get('record_id')}")  # Showing IDs to see shuffle

    # Deleting a dataset makes later access raise ValueError.
    store.delete_dataset("images")
    try:
        store.get_dataset("images")
    except ValueError as e:
        print(f"\nSuccessfully deleted 'images' dataset: {e}")