# File size: 15,275 Bytes
# aa654a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362

import torch
from typing import List, Dict, Callable, Optional, Any
import logging
import time
import uuid
import random # Added for sampling

# Configure basic logging.
# NOTE: this call runs at import time and targets the root logger: INFO level,
# with each line formatted as "<timestamp> - <level name> - <message>".
# Per the standard-library docs, basicConfig() does nothing when the root
# logger already has handlers, so an application that configures logging
# before importing this module keeps its own settings.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class TensorStorage:
    """
    Manages named datasets stored in memory as collections of tensors.

    Every dataset is a dict holding two index-aligned lists, ``"tensors"`` and
    ``"metadata"``: entry ``i`` of one list describes entry ``i`` of the other.
    Nothing is persisted; the data lives only as long as the instance does.

    Records returned by the retrieval methods are dicts of the form
    ``{"tensor": ..., "metadata": ...}`` that reference the stored objects
    directly (no copies are made on read).
    """

    def __init__(self):
        """Initializes the TensorStorage with an empty dictionary for datasets."""
        # In-memory storage. Replace with persistent storage solution for production.
        # Structure: { dataset_name: { "tensors": List[Tensor], "metadata": List[Dict] } }
        self.datasets: Dict[str, Dict[str, List[Any]]] = {}
        logging.info("TensorStorage initialized (In-Memory).")

    def _require_dataset(self, name: str, action: str, hint: str = "") -> Dict[str, List[Any]]:
        """
        Returns the storage dict of an existing dataset.

        Args:
            name: The name of the dataset to look up.
            action: Short description of the calling operation, used in the
                    error log line (e.g. "insertion", "sampling").
            hint: Optional text appended to the ValueError message.

        Raises:
            ValueError: If the dataset does not exist.
        """
        try:
            return self.datasets[name]
        except KeyError:
            logging.error("Dataset '%s' not found for %s.", name, action)
            # "from None": callers should see a plain ValueError, not the KeyError.
            raise ValueError(f"Dataset '{name}' does not exist.{hint}") from None

    @staticmethod
    def _as_record(tensor: torch.Tensor, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Packs a stored tensor and its metadata into the public record shape."""
        return {"tensor": tensor, "metadata": metadata}

    def create_dataset(self, name: str) -> None:
        """
        Creates a new, empty dataset.

        Args:
            name: The unique name for the new dataset.

        Raises:
            ValueError: If a dataset with the same name already exists.
        """
        if name in self.datasets:
            logging.warning("Attempted to create dataset '%s' which already exists.", name)
            raise ValueError(f"Dataset '{name}' already exists.")

        self.datasets[name] = {"tensors": [], "metadata": []}
        logging.info("Dataset '%s' created successfully.", name)

    def insert(self, name: str, tensor: torch.Tensor, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Inserts a tensor into a specified dataset.

        A copy of the tensor (``tensor.clone()``) is stored together with a
        metadata dict. The metadata always contains the generated fields
        ``record_id``, ``timestamp_utc``, ``shape``, ``dtype`` and ``version``;
        any of them can be overridden by a same-named key in ``metadata``
        (a warning is logged for every overridden key except ``record_id``).

        Args:
            name: The name of the dataset to insert into.
            tensor: The PyTorch tensor to insert.
            metadata: Optional dictionary containing metadata about the tensor
                      (e.g., source, timestamp, custom tags). It is not modified.

        Returns:
            str: The record ID stored with the tensor. This is a freshly
            generated UUID4 string unless ``metadata`` supplies its own
            ``record_id``, in which case that value is returned so that it
            can be passed to ``get_tensor_by_id``.

        Raises:
            ValueError: If the dataset does not exist.
            TypeError: If the provided object is not a PyTorch tensor.
        """
        dataset = self._require_dataset(name, "insertion", hint=" Create it first.")

        if not isinstance(tensor, torch.Tensor):
            logging.error("Attempted to insert non-tensor data into dataset '%s'.", name)
            raise TypeError("Data to be inserted must be a torch.Tensor.")

        user_metadata = {} if metadata is None else metadata

        generated_metadata = {
            "record_id": str(uuid.uuid4()),
            "timestamp_utc": time.time(),
            "shape": tuple(tensor.shape),
            "dtype": str(tensor.dtype),
            # Placeholder for versioning - simple sequence for now
            "version": len(dataset["tensors"]) + 1,
        }

        # Overriding record_id is tolerated silently; any other generated key
        # supplied by the caller is flagged because it replaces a derived value.
        for key in generated_metadata:
            if key != "record_id" and key in user_metadata:
                logging.warning("Provided metadata key '%s' might conflict with generated defaults.", key)

        # User-supplied values win; generated values only fill in missing keys.
        # Building a new dict keeps the caller's dict untouched.
        final_metadata = dict(user_metadata)
        for key, value in generated_metadata.items():
            final_metadata.setdefault(key, value)

        # Return the ID that was actually stored, so a user-provided record_id
        # stays retrievable through get_tensor_by_id().
        record_id = final_metadata["record_id"]

        # --- Placeholder for Chunking Logic ---
        # In a real implementation, large tensors would be chunked here.
        # Each chunk would be stored separately with associated metadata.
        # For now, we store the whole tensor.
        # ------------------------------------

        dataset["tensors"].append(tensor.clone())  # Store a copy
        dataset["metadata"].append(final_metadata)
        logging.debug(
            "Tensor with shape %s inserted into dataset '%s'. Record ID: %s",
            tuple(tensor.shape), name, record_id,
        )
        return record_id

    def get_dataset(self, name: str) -> List[torch.Tensor]:
        """
        Retrieves all tensors from a specified dataset.

        Args:
            name: The name of the dataset to retrieve.

        Returns:
            A new list containing all tensors in the dataset, in insertion
            order. The list itself is a shallow copy, so adding or removing
            list items does not affect the storage; the tensor objects are
            the stored ones.

        Raises:
            ValueError: If the dataset does not exist.
        """
        dataset = self._require_dataset(name, "retrieval")

        logging.debug("Retrieving all %s tensors from dataset '%s'.", len(dataset["tensors"]), name)
        # --- Placeholder for Reassembling Chunks ---
        # If data was chunked, it would be reassembled here before returning.
        # -----------------------------------------
        # Hand out a copy of the list: exposing the internal list would let a
        # caller desynchronise it from the parallel metadata list.
        return list(dataset["tensors"])

    def get_dataset_with_metadata(self, name: str) -> List[Dict[str, Any]]:
        """
        Retrieves all tensors and their metadata from a specified dataset.

        Args:
            name: The name of the dataset to retrieve.

        Returns:
            A list of dictionaries, each containing a 'tensor' and its 'metadata'.

        Raises:
            ValueError: If the dataset does not exist.
        """
        dataset = self._require_dataset(name, "retrieval with metadata")

        logging.debug(
            "Retrieving all %s tensors and metadata from dataset '%s'.",
            len(dataset["tensors"]), name,
        )
        return [
            self._as_record(tensor, meta)
            for tensor, meta in zip(dataset["tensors"], dataset["metadata"])
        ]

    def query(self, name: str, query_fn: Callable[[torch.Tensor, Dict[str, Any]], bool]) -> List[Dict[str, Any]]:
        """
        Queries a dataset using a function that filters tensors based on the
        tensor data itself and/or its metadata.

        Records for which ``query_fn`` raises are skipped with a logged
        warning; the query itself does not fail.

        Args:
            name: The name of the dataset to query.
            query_fn: A callable that takes a tensor and its metadata dictionary
                      as input and returns True if the tensor should be included
                      in the result, False otherwise.

        Returns:
            A list of dictionaries, each containing a 'tensor' and its 'metadata'
            that satisfy the query function.

        Raises:
            ValueError: If the dataset does not exist.
            TypeError: If query_fn is not callable.
        """
        dataset = self._require_dataset(name, "querying")

        if not callable(query_fn):
            logging.error("Provided query_fn is not callable for dataset '%s'.", name)
            raise TypeError("query_fn must be a callable function.")

        logging.debug("Querying dataset '%s' with custom function.", name)
        results = []
        # --- Placeholder for Optimized Querying ---
        # In a real system, metadata indexing would speed this up significantly.
        # Query might operate directly on chunks or specific metadata fields first.
        # ----------------------------------------
        for tensor, meta in zip(dataset["tensors"], dataset["metadata"]):
            try:
                # bool() stays inside the try: truth-testing the predicate's
                # result can itself raise (e.g. a multi-element tensor).
                matched = bool(query_fn(tensor, meta))
            except Exception as e:
                # Deliberate best-effort: a predicate failing on one record
                # must not abort the whole query.
                logging.warning(
                    "Error executing query_fn on tensor %s in dataset '%s': %s",
                    meta.get('record_id', 'N/A'), name, e,
                )
                continue
            if matched:
                results.append(self._as_record(tensor, meta))

        logging.info("Query on dataset '%s' returned %s results.", name, len(results))
        return results

    def get_tensor_by_id(self, name: str, record_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a specific tensor and its metadata by its unique record ID.

        Args:
            name: The name of the dataset.
            record_id: The unique ID of the record to retrieve.

        Returns:
            A dictionary containing the 'tensor' and 'metadata' of the first
            record whose metadata 'record_id' equals ``record_id``, or None if
            not found.

        Raises:
            ValueError: If the dataset does not exist.
        """
        dataset = self._require_dataset(name, "get_tensor_by_id")

        # This is inefficient for large datasets; requires an index in a real system.
        for tensor, meta in zip(dataset["tensors"], dataset["metadata"]):
            if meta.get("record_id") == record_id:
                logging.debug("Tensor with record_id '%s' found in dataset '%s'.", record_id, name)
                return self._as_record(tensor, meta)

        logging.warning("Tensor with record_id '%s' not found in dataset '%s'.", record_id, name)
        return None

    def sample_dataset(self, name: str, n_samples: int) -> List[Dict[str, Any]]:
        """
        Retrieves a random sample of records (tensor and metadata) from a dataset.

        Sampling is without replacement and uses the ``random`` module.

        Args:
            name: The name of the dataset to sample from.
            n_samples: The number of samples to retrieve. Values <= 0 yield
                       an empty list.

        Returns:
            A list of dictionaries, each containing 'tensor' and 'metadata' for
            the sampled records. If ``n_samples`` is at least the dataset size,
            every record is returned once, in random order.

        Raises:
            ValueError: If the dataset does not exist.
        """
        dataset = self._require_dataset(name, "sampling")

        dataset_size = len(dataset["tensors"])
        if n_samples <= 0:
            return []

        # Only warn when the request genuinely exceeds what is available;
        # asking for exactly the dataset size is a legitimate request.
        if n_samples > dataset_size:
            logging.warning(
                "Requested %s samples from dataset '%s' which only has %s items. "
                "Returning all items shuffled.",
                n_samples, name, dataset_size,
            )

        # random.sample with k == population size yields a full random
        # permutation, so one call covers both the partial and the "all" case.
        indices = random.sample(range(dataset_size), min(n_samples, dataset_size))

        logging.debug("Sampling %s records from dataset '%s'.", len(indices), name)

        # In-memory sampling is easy. For persistent storage, this would
        # likely involve optimized queries or index lookups.
        return [
            self._as_record(dataset["tensors"][i], dataset["metadata"][i])
            for i in indices
        ]

    def delete_dataset(self, name: str) -> bool:
        """
        Deletes an entire dataset. Use with caution!

        Args:
            name: The name of the dataset to delete.

        Returns:
            True if the dataset was deleted, False if it didn't exist.
        """
        if name not in self.datasets:
            logging.warning("Attempted to delete non-existent dataset '%s'.", name)
            return False

        del self.datasets[name]
        logging.warning("Dataset '%s' has been permanently deleted.", name)
        return True

# Example Usage (can be run directly if needed)
if __name__ == "__main__":
    # Demo script: walks through each public TensorStorage operation once and
    # prints the outcome.
    store = TensorStorage()

    # Set up the two demo datasets.
    for dataset_name in ("images", "sensor_readings"):
        store.create_dataset(dataset_name)

    # Build the demo tensors first, then insert them.
    image = torch.rand(3, 64, 64)  # Example image tensor (Channels, H, W)
    sensor_inputs = [
        (torch.tensor([10.5, 11.2, 10.9]), {"sensor_id": "XYZ", "location": "lab1"}),
        (torch.tensor([11.1, 11.5, 11.3]), {"sensor_id": "XYZ", "location": "lab1"}),
        (torch.tensor([9.8, 10.1, 9.9]), {"sensor_id": "ABC", "location": "lab2"}),
    ]

    image_id = store.insert("images", image, metadata={"source": "camera_A", "label": "cat"})
    reading_ids = [
        store.insert("sensor_readings", reading, metadata=tags)
        for reading, tags in sensor_inputs
    ]

    print(f"Inserted image with ID: {image_id}")
    for number, reading_id in enumerate(reading_ids, start=1):
        print(f"Inserted sensor reading {number} with ID: {reading_id}")

    # Fetch the whole sensor dataset, tensors plus metadata.
    sensor_records = store.get_dataset_with_metadata("sensor_readings")
    print(f"\nRetrieved {len(sensor_records)} sensor records:")
    for record in sensor_records:
        print(f"  Metadata: {record['metadata']}, Tensor shape: {record['tensor'].shape}")

    # Filter on tensor contents.
    def first_value_above_eleven(tensor, meta):
        return tensor[0].item() > 11.0

    print("\nQuerying sensor readings with first value > 11.0:")
    for record in store.query("sensor_readings", first_value_above_eleven):
        print(f"  Metadata: {record['metadata']}, Tensor: {record['tensor']}")

    # Filter on metadata only.
    def from_sensor_xyz(tensor, meta):
        return meta.get("sensor_id") == "XYZ"

    print("\nQuerying sensor readings from sensor 'XYZ':")
    for record in store.query("sensor_readings", from_sensor_xyz):
        print(f"  Metadata: {record['metadata']}, Tensor: {record['tensor']}")

    # Look a single record up by the ID returned from insert().
    first_reading_id = reading_ids[0]
    print(f"\nRetrieving sensor reading with ID {first_reading_id}:")
    found = store.get_tensor_by_id("sensor_readings", first_reading_id)
    if found:
        print(f"  Metadata: {found['metadata']}, Tensor: {found['tensor']}")

    # Random sample smaller than the dataset.
    print(f"\nSampling 2 records from sensor_readings:")
    small_sample = store.sample_dataset("sensor_readings", 2)
    print(f" Got {len(small_sample)} samples.")
    for position, record in enumerate(small_sample, start=1):
        print(f"  Sample {position} - Record ID: {record['metadata'].get('record_id')}, Tensor shape: {record['tensor'].shape}")

    # Random sample larger than the dataset; IDs are printed to show the shuffle.
    print(f"\nSampling 5 records (more than available):")
    oversized_sample = store.sample_dataset("sensor_readings", 5)
    print(f" Got {len(oversized_sample)} samples.")
    for position, record in enumerate(oversized_sample, start=1):
        print(f"  Sample {position} - Record ID: {record['metadata'].get('record_id')}")

    # Remove a dataset and confirm that reading it afterwards fails.
    store.delete_dataset("images")
    try:
        store.get_dataset("images")
    except ValueError as err:
        print(f"\nSuccessfully deleted 'images' dataset: {err}")