File size: 18,945 Bytes
aa654a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
# nql_agent.py
"""
Implements the NQL (Natural Query Language) Query Agent for Tensorus.

This agent provides a basic natural language interface to query datasets
stored in TensorStorage. It uses regular expressions to parse simple queries
and translates them into calls to TensorStorage.query or other methods.

Limitations (without LLM):
- Understands only a very limited set of predefined sentence structures.
- Limited support for complex conditions (AND/OR not implemented).
- Limited support for data types in conditions (primarily numbers and exact strings).
- No support for aggregations (mean, sum, etc.) beyond simple counts.
- Error handling for parsing ambiguity is basic.

Future Enhancements:
- Integrate a local or remote LLM for robust NLU.
- Support for complex queries (multiple conditions, joins).
- Support for aggregations and projections (selecting specific fields).
- More sophisticated error handling and user feedback.
- Context awareness and conversation history.
"""

import logging
import operator
import re
from typing import List, Dict, Any, Optional, Callable, Tuple

import torch

from tensor_storage import TensorStorage # Import our storage module

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class NQLAgent:
    """Parses simple natural language queries and executes them against TensorStorage.

    Each supported sentence shape is recognised by one compiled regular
    expression (the ``pattern_*`` class attributes) and translated into a call
    to the wrapped storage's ``get_dataset_with_metadata`` or ``query`` method.
    ``process_query`` reports failures in its returned dict instead of raising;
    the dict always has the keys ``success``, ``message``, ``count`` and
    ``results``.
    """

    # --- Regex patterns for query parsing ---
    # The patterns never change, so they are compiled once for the class rather
    # than once per instance. Instances still see them as ``self.pattern_*``.

    # "get/show/find [all] <data|tensor(s)|record(s)|entries|experience(s)> from [dataset] Y"
    # or "get/show/find everything from [dataset] Y".
    # Captures: dataset_name
    pattern_get_all = re.compile(
        r"^(?:get|show|find)\s+(?:(?:all\s+)?(?:data|tensors?|records?|entries|experiences?)|everything)\s+from\s+(?:dataset\s+)?([\w_.-]+)$",
        re.IGNORECASE,
    )

    # Metadata filter: "... from Y where meta_key op value".
    # Captures: dataset_name, key, operator, value.
    # The value is a number or simple string, optionally in single or double quotes.
    pattern_filter_meta = re.compile(
        r"^(?:get|show|find)\s+.*\s+from\s+([\w_.-]+)\s+where\s+([\w_.-]+)\s*([<>=!]+)\s*[\"']?([\w\s\d_.-]+?)[\"']?$",
        re.IGNORECASE,
    )

    # Same as pattern_filter_meta, but with 'is' / 'equals' / 'eq' meaning '='.
    # Captures: dataset_name, key, value.
    pattern_filter_meta_alt = re.compile(
        r"^(?:get|show|find)\s+.*\s+from\s+([\w_.-]+)\s+where\s+([\w_.-]+)\s+(?:is|equals|eq)\s+[\"']?([\w\s\d_.-]+?)[\"']?$",
        re.IGNORECASE,
    )

    # Tensor filter: "... from Y where tensor[index] op number" ('value' may be
    # used instead of 'tensor'). Captures: dataset_name, index (None when
    # omitted), operator, value. The index addresses the *flattened* tensor;
    # multi-dimensional indexing such as tensor[0, 1] is not supported.
    pattern_filter_tensor = re.compile(
        r"^(?:get|show|find)\s+.*\s+from\s+([\w_.-]+)\s+where\s+(?:tensor|value)\s*(?:\[(\d+)\])?\s*([<>=!]+)\s*([\d.-]+)$",
        re.IGNORECASE,
    )

    # Count: "count [all] <noun> in|from [dataset] Y". Captures: dataset_name.
    pattern_count = re.compile(
        r"^count\s+(?:all\s+)?(?:data|tensors?|records?|entries|experiences?)\s+(?:in|from)\s+(?:dataset\s+)?([\w_.-]+)$",
        re.IGNORECASE,
    )

    # Comparison operators accepted in filter clauses; '=' and '==' are synonyms.
    _OPERATORS: Dict[str, Callable[[Any, Any], Any]] = {
        "=": operator.eq,
        "==": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
    }

    def __init__(self, tensor_storage: "TensorStorage"):
        """
        Initializes the NQL Agent.

        Args:
            tensor_storage: An instance of the TensorStorage class.

        Raises:
            TypeError: if ``tensor_storage`` is not a TensorStorage instance.
        """
        if not isinstance(tensor_storage, TensorStorage):
            raise TypeError("tensor_storage must be an instance of TensorStorage")
        self.tensor_storage = tensor_storage
        logger.info("NQLAgent initialized with basic regex patterns.")

    def _parse_operator_and_value(self, op_str: str, val_str: str) -> Tuple[Callable, Any]:
        """Resolve an operator token and convert the value token to a number when possible.

        Args:
            op_str: One of '=', '==', '!=', '<', '<=', '>', '>='.
            val_str: Raw value text; surrounding whitespace is ignored.

        Returns:
            ``(op_func, value)`` where ``op_func(a, b)`` applies the comparison
            and ``value`` is an int (also for integral floats such as "2.0"),
            a float, or the stripped string when it is not numeric.

        Raises:
            ValueError: if ``op_str`` is not a supported operator.
        """
        op_func = self._OPERATORS.get(op_str)
        if op_func is None:
            raise ValueError(f"Unsupported operator: {op_str}")

        val_str = val_str.strip()
        # Parse integers directly so large values keep full precision; going
        # through float first would round anything beyond 2**53.
        try:
            return op_func, int(val_str)
        except ValueError:
            pass
        try:
            value = float(val_str)
        except ValueError:
            return op_func, val_str  # Not numeric: compare as a string.
        return op_func, int(value) if value.is_integer() else value

    @staticmethod
    def _coerce_for_comparison(actual_value: Any, filter_value: Any) -> Any:
        """Convert a metadata value so it can be compared against ``filter_value``.

        Numeric filter: non-numeric metadata is converted to the filter's type,
        falling back to ``float`` so that e.g. the string "5.2" still compares
        against an int filter such as 0. String filter: non-string metadata is
        converted with ``str()``. Anything else is returned unchanged.

        Raises:
            ValueError, TypeError, OverflowError: if no numeric conversion exists.
        """
        if isinstance(filter_value, (int, float)):
            if isinstance(actual_value, (int, float)):
                return actual_value
            try:
                return type(filter_value)(actual_value)
            except (ValueError, TypeError, OverflowError):
                return float(actual_value)
        if isinstance(filter_value, str) and not isinstance(actual_value, str):
            return str(actual_value)
        return actual_value

    @staticmethod
    def _error_response(message: str) -> Dict[str, Any]:
        """Build the failure payload returned by ``process_query``."""
        return {"success": False, "message": message, "count": None, "results": None}

    @staticmethod
    def _filter_response(results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the success payload shared by the metadata and tensor filters."""
        count = len(results)
        return {
            "success": True,
            "message": f"Query executed successfully. Found {count} matching records.",
            "count": count,
            "results": results,
        }

    def _execute(self, label: str, action: Callable[[], Dict[str, Any]], verb: str = "during") -> Dict[str, Any]:
        """Run ``action`` and turn any exception into the failure payload.

        ``ValueError`` is treated as an expected failure: an unsupported
        operator, a non-numeric tensor comparison, or any ValueError raised by
        the storage layer. Any other exception is logged with its traceback.
        ``label`` and ``verb`` only shape the log message.
        """
        try:
            return action()
        except ValueError as e:
            logger.error(f"Error {verb} {label} query: {e}")
            return self._error_response(str(e))
        except Exception as e:
            logger.error(f"Unexpected error during {label} query: {e}", exc_info=True)
            return self._error_response(f"An unexpected error occurred: {e}")

    def _count(self, dataset_name: str) -> Dict[str, Any]:
        """Handle a COUNT query: report how many records ``dataset_name`` holds."""
        # Inefficient: fetches every record just to count them.
        # TODO: Add count method to TensorStorage
        count = len(self.tensor_storage.get_dataset_with_metadata(dataset_name))
        return {
            "success": True,
            "message": f"Found {count} records in dataset '{dataset_name}'.",
            "count": count,
            "results": None,
        }

    def _get_all(self, dataset_name: str) -> Dict[str, Any]:
        """Handle a GET ALL query: return every record of ``dataset_name``."""
        results = self.tensor_storage.get_dataset_with_metadata(dataset_name)
        count = len(results)
        return {
            "success": True,
            "message": f"Retrieved {count} records from dataset '{dataset_name}'.",
            "count": count,
            "results": results,
        }

    def _filter_meta(self, dataset_name: str, key: str, op_str: str, val_str: str) -> Dict[str, Any]:
        """Handle a metadata filter: keep records whose ``metadata[key]`` satisfies ``op value``."""
        op_func, filter_value = self._parse_operator_and_value(op_str, val_str)

        def query_fn_meta(tensor: torch.Tensor, metadata: Dict[str, Any]) -> bool:
            actual_value = metadata.get(key)
            if actual_value is None:
                return False  # Key missing (or None) in this record's metadata.
            try:
                actual_value = self._coerce_for_comparison(actual_value, filter_value)
            except (ValueError, TypeError, OverflowError):
                return False  # The types cannot be reconciled.
            try:
                return op_func(actual_value, filter_value)
            except TypeError:
                return False  # Mismatched types that survived coercion.
            except Exception as e_inner:
                logger.warning(f"Error during query_fn execution for key '{key}': {e_inner}")
                return False

        return self._filter_response(self.tensor_storage.query(dataset_name, query_fn_meta))

    def _filter_tensor(self, dataset_name: str, index_str: Optional[str], op_str: str, val_str: str) -> Dict[str, Any]:
        """Handle a tensor filter: compare one element of each flattened tensor with a number.

        Raises:
            ValueError: if the comparison value is not numeric.
        """
        op_func, filter_value = self._parse_operator_and_value(op_str, val_str)
        if not isinstance(filter_value, (int, float)):
            raise ValueError(f"Tensor value filtering currently only supports numeric comparisons. Got value: {filter_value}")

        # Omitting the index compares against the first element (flat index 0).
        tensor_index = 0 if index_str is None else int(index_str)

        def query_fn_tensor(tensor: torch.Tensor, metadata: Dict[str, Any]) -> bool:
            try:
                # reshape (unlike view) also accepts non-contiguous tensors,
                # e.g. transposed ones; it only copies when it has to.
                flat = tensor.reshape(-1)
                if tensor_index >= flat.numel():
                    return False  # Index out of bounds (or empty tensor).
                return op_func(flat[tensor_index].item(), filter_value)
            except IndexError:
                return False
            except Exception as e_inner:
                logger.warning(f"Error during query_fn_tensor execution: {e_inner}")
                return False

        return self._filter_response(self.tensor_storage.query(dataset_name, query_fn_tensor))

    def process_query(self, query: str) -> Dict[str, Any]:
        """
        Processes a natural language query string.

        Patterns are tried in this order: count, get-all, tensor filter, then
        metadata filter (operator form, then the 'is'/'equals'/'eq' form). The
        tensor filter must come before the metadata filter because the
        metadata key pattern also accepts the words 'tensor' and 'value'.
        Otherwise "... where value = -5.0" would search for a metadata key
        named 'value' instead of comparing tensor elements.

        Args:
            query: The natural language query.

        Returns:
            A dictionary containing:
                'success': bool, indicating if the query was processed.
                'message': str, status message or error description.
                'count': Optional[int], number of results found.
                'results': Optional[List[Dict]], the list of matching records
                           (each a dict with 'tensor' and 'metadata');
                           None for count queries and failures.
        """
        query = query.strip()
        logger.info(f"Processing NQL query: '{query}'")

        # --- 1. Count Pattern ---
        match = self.pattern_count.match(query)
        if match:
            dataset_name = match.group(1)
            logger.debug(f"Matched COUNT pattern for dataset '{dataset_name}'")
            return self._execute("COUNT", lambda: self._count(dataset_name))

        # --- 2. Get All Pattern ---
        match = self.pattern_get_all.match(query)
        if match:
            dataset_name = match.group(1)
            logger.debug(f"Matched GET ALL pattern for dataset '{dataset_name}'")
            return self._execute("GET ALL", lambda: self._get_all(dataset_name))

        # --- 3. Filter Tensor Pattern (must precede the metadata filter) ---
        match = self.pattern_filter_tensor.match(query)
        if match:
            dataset_name, index_str, op_str, val_str = match.groups()
            logger.debug(f"Matched FILTER TENSOR pattern: dataset='{dataset_name}', index='{index_str}', op='{op_str}', value='{val_str}'")
            return self._execute(
                "FILTER TENSOR",
                lambda: self._filter_tensor(dataset_name, index_str, op_str, val_str),
                verb="processing",
            )

        # --- 4. Filter Metadata Pattern (operator form, then is/equals/eq form) ---
        match = self.pattern_filter_meta.match(query)
        if match:
            dataset_name, key, op_str, val_str = match.groups()
            logger.debug(f"Matched FILTER META pattern: dataset='{dataset_name}', key='{key}', op='{op_str}', value='{val_str}'")
        else:
            match = self.pattern_filter_meta_alt.match(query)
            if match:
                dataset_name, key, val_str = match.groups()
                op_str = '='  # Implicitly '=' for 'is/equals/eq'
                logger.debug(f"Matched FILTER META ALT pattern: dataset='{dataset_name}', key='{key}', op='{op_str}', value='{val_str}'")
        if match:
            return self._execute(
                "FILTER META",
                lambda: self._filter_meta(dataset_name, key, op_str, val_str),
                verb="processing",
            )

        # --- No Match Found ---
        logger.warning(f"Query did not match any known patterns: '{query}'")
        return {
            "success": False,
            "message": "Sorry, I couldn't understand that query. Try simple commands like 'get all data from my_dataset' or 'find records from my_dataset where key = value'.",
            "count": None,
            "results": None
        }


# --- Example Usage ---
if __name__ == "__main__":
    # Demo: populate an in-memory TensorStorage, then run a batch of sample
    # queries through NQLAgent and print a short summary of each response.
    print("--- Starting NQL Agent Example ---")

    # 1. Build the storage with two small demo datasets.
    storage = TensorStorage()
    for ds_name in ("sensor_data", "rl_experiences_test"):
        storage.create_dataset(ds_name)

    sensor_rows = [
        ([10.5, 25.2], {"sensor_id": "A001", "location": "floor1", "status": "active"}),
        ([12.1, 26.8], {"sensor_id": "A002", "location": "floor1", "status": "active"}),
        ([-5.0, 24.1], {"sensor_id": "B001", "location": "floor2", "status": "inactive"}),
    ]
    for reading, sensor_meta in sensor_rows:
        storage.insert("sensor_data", torch.tensor(reading), metadata=sensor_meta)

    # Dummy experience records (per the original comment, these follow the structure RLAgent stores).
    experience_rows = [
        {"state_id": "s1", "action": 0, "reward": -1.5, "next_state_id": "s2", "done": 0},
        {"state_id": "s2", "action": 1, "reward": 5.2, "next_state_id": "s3", "done": 0},
        {"state_id": "s3", "action": 0, "reward": -8.0, "next_state_id": None, "done": 1},
    ]
    for exp_meta in experience_rows:
        storage.insert("rl_experiences_test", torch.tensor([1.0]), metadata=exp_meta)

    # 2. Wrap the storage in the agent.
    nql_agent = NQLAgent(storage)

    # 3. Sample queries, covering both supported and deliberately unsupported forms.
    queries = [
        "get all data from sensor_data",
        "show all records from rl_experiences_test",
        "count records in sensor_data",
        "find tensors from sensor_data where sensor_id = 'A001'", # Metadata string eq
        "find data from sensor_data where location is 'floor1'", # Metadata string eq (alt syntax)
        "get records from sensor_data where status != 'active'", # Metadata string neq
        "find experiences from rl_experiences_test where reward > 0", # Metadata numeric gt
        "get experiences from rl_experiences_test where reward < -5", # Metadata numeric lt
        "find entries from rl_experiences_test where done == 1", # Metadata numeric eq (bool as int)
        "get records from sensor_data where value[0] > 11", # Tensor value gt at index 0
        "find tensors from sensor_data where tensor[1] < 25", # Tensor value lt at index 1
        "show data from sensor_data where value = -5.0", # Tensor value eq (omitting index -> first element)
        "get everything from non_existent_dataset", # Test non-existent dataset
        "find data from sensor_data where invalid_key = 10", # Test non-existent key
        "give me the average sensor reading", # Test unsupported query
        "select * from sensor_data" # Test SQL-like query (unsupported)
    ]

    # 4. Run each query and print a compact report.
    print("\n--- Processing Queries ---")
    max_shown = 3  # Only the first few records of each result are printed.
    for q in queries:
        print(f"\n> Query: \"{q}\"")
        response = nql_agent.process_query(q)
        succeeded = response['success']
        found = response['results']
        total = response['count']
        print(f"< Success: {succeeded}")
        print(f"< Message: {response['message']}")

        if not succeeded:
            continue
        if found is None:
            # Count queries carry a count but no records.
            if total is not None:
                print(f"< Count: {total}")
            continue

        print(f"< Count: {total}")
        for pos, record in enumerate(found):
            if pos >= max_shown:
                print(f"  ... (omitting {len(found) - max_shown} more results)")
                break
            tensor = record['tensor']
            # Summarise the tensor instead of dumping its values.
            tensor_str = f"Tensor(shape={tensor.shape}, dtype={tensor.dtype})"
            print(f"  - Result {pos+1}: Metadata={record['metadata']}, Tensor={tensor_str}")

    print("\n--- NQL Agent Example Finished ---")