# core/tests/test_nql_agent_basic.py
# tensorus's picture
# Upload 83 files
# edfa748 verified
import pytest
pytest.importorskip("torch")
from unittest.mock import patch
import torch
from tensorus.nql_agent import NQLAgent
from tensorus.tensor_storage import TensorStorage
@pytest.fixture
def basic_storage():
    """Build a ``TensorStorage`` seeded with the ``ds1`` and ``ds2`` datasets.

    ``ds1`` holds four single-element tensors whose metadata mixes strings,
    ints, floats and booleans (one record lacks ``is_calibrated`` and one has
    the string ``"N/A"`` as its ``reading``). ``ds2`` holds three tensors of
    different lengths, one of them empty.
    """
    seed_records = (
        ("ds1", (
            ([1.0], {"id": "007", "name": "Sensor Alpha", "status": "active",
                     "reading": 100, "level": 5.5, "is_calibrated": True}),
            ([2.0], {"id": "008", "name": "Sensor Beta", "status": "inactive",
                     "reading": 200, "level": 5.5, "is_calibrated": False}),
            # No "is_calibrated" key on this record.
            ([3.0], {"id": "ABC", "name": "Sensor Gamma", "status": "active",
                     "reading": 150, "level": 6.0}),
            # "reading" is a string here rather than a number.
            ([4.0], {"id": "009", "name": "Sensor Delta", "status": "active",
                     "reading": "N/A"}),
        )),
        ("ds2", (
            ([0.1, 0.2], {"record_id": "rec1", "value": 10}),
            ([0.3, 0.4, 0.5], {"record_id": "rec2", "value": 20}),
            # Empty tensor.
            ([], {"record_id": "rec3", "value": 0}),
        )),
    )

    # storage_path=None: in-memory storage for tests.
    store = TensorStorage(storage_path=None)
    for dataset_name, rows in seed_records:
        store.create_dataset(dataset_name)
        for values, meta in rows:
            store.insert(dataset_name, torch.tensor(values), metadata=meta)
    return store
def test_count_query(basic_storage):
    """Count queries succeed and report each dataset's record count."""
    agent = NQLAgent(basic_storage)
    cases = (
        ("count records in ds1", 4),
        ("count records from ds2", 3),
    )
    for query_text, expected_count in cases:
        outcome = agent.process_query(query_text)
        assert outcome["success"]
        assert outcome["count"] == expected_count
def test_count_query_missing_dataset(basic_storage):
    """Counting an unknown dataset fails with a 'not found' message and no count."""
    outcome = NQLAgent(basic_storage).process_query("count records in missing_ds")
    assert not outcome["success"]
    assert outcome["message"] == "Dataset 'missing_ds' not found."
    assert outcome["count"] is None
def test_get_all_query(basic_storage):
    """'get all' / 'show all' queries return every record of the dataset."""
    agent = NQLAgent(basic_storage)

    ds1_outcome = agent.process_query("get all data from ds1")
    assert ds1_outcome["success"]
    assert ds1_outcome["count"] == 4
    ds1_rows = ds1_outcome["results"]
    assert len(ds1_rows) == 4
    # The first returned record is expected to be the first one inserted.
    assert ds1_rows[0]["metadata"]["id"] == "007"

    ds2_outcome = agent.process_query("show all tensors from ds2")
    assert ds2_outcome["success"]
    assert ds2_outcome["count"] == 3
    assert len(ds2_outcome["results"]) == 3
def test_get_all_query_missing_dataset(basic_storage):
    """A 'get all' query against an unknown dataset fails and names that dataset."""
    outcome = NQLAgent(basic_storage).process_query("get all records from missing_ds")
    assert not outcome["success"]
    # The exact error wording is deliberately not pinned: per the original
    # author's notes it may originate from TensorStorage rather than from
    # NQLAgent (the count path has explicit missing-dataset handling, the
    # get-all path may not). Only require that the message mentions the
    # dataset name; tighten this if a specific message becomes guaranteed.
    assert "missing_ds" in outcome["message"].lower()
# --- Metadata Filter Tests (Focus of T1 from planning) ---
@pytest.mark.parametrize(("query", "expected_ids"), [
    # --- "id" is stored as a string: "007", "008", "ABC", "009" ---
    ("find records from ds1 where id = '007'", ["007"]),
    ("find records from ds1 where id == \"007\"", ["007"]),  # double-quoted value
    ("find records from ds1 where id = 'ABC'", ["ABC"]),
    # Unquoted 007 is expected to be read as the number 7, which must not
    # equal the string "007".
    ("find records from ds1 where id = 007", []),
    ("find records from ds1 where id != '007'", ["008", "ABC", "009"]),
    # Number 7 differs from every string id, "007" included.
    ("find records from ds1 where id != 007", ["007", "008", "ABC", "009"]),

    # --- Numeric metadata: reading is 100 / 200 / 150, level is 5.5 / 6.0 ---
    ("find records from ds1 where reading = 100", ["007"]),
    ("find records from ds1 where reading == 150", ["ABC"]),
    # Quoted "100" against the stored int 100: expected to match via coercion.
    ("find records from ds1 where reading = '100'", ["007"]),
    ("find records from ds1 where reading > 100", ["008", "ABC"]),
    ("find records from ds1 where reading >= 150", ["008", "ABC"]),
    ("find records from ds1 where reading < 150", ["007"]),
    ("find records from ds1 where reading <= 100", ["007"]),
    # Record "009" (reading "N/A") also counts as "not 100".
    ("find records from ds1 where reading != 100", ["008", "ABC", "009"]),
    ("find records from ds1 where level = 5.5", ["007", "008"]),
    ("find records from ds1 where level > 5.5", ["ABC"]),
    # Quoted "5.5" against the stored float 5.5.
    ("find records from ds1 where level = '5.5'", ["007", "008"]),

    # --- String metadata: status is "active" / "inactive" ---
    ("find records from ds1 where status = 'active'", ["007", "ABC", "009"]),
    # Alternative spellings of the equality operator.
    ("find records from ds1 where status is 'inactive'", ["008"]),
    ("find records from ds1 where status equals \"active\"", ["007", "ABC", "009"]),
    ("find records from ds1 where status eq 'active'", ["007", "ABC", "009"]),

    # --- Boolean metadata: is_calibrated is True, False, or absent ---
    # Per the original author's notes, a string filter value causes the
    # stored value to be compared as a string, so True matches 'True'.
    ("find records from ds1 where is_calibrated = 'True'", ["007"]),
    ("find records from ds1 where is_calibrated = 'False'", ["008"]),
    # NOTE: unquoted `is_calibrated = True` and `is_calibrated = 1` are
    # intentionally not covered; the original notes say the former would be
    # parsed as the string "True" and the latter needs specific handling or
    # relies on coercion.

    # --- "reading" holds the string "N/A" on record "009" ---
    ("find records from ds1 where reading = 'N/A'", ["009"]),
    ("find records from ds1 where reading != 'N/A'", ["007", "008", "ABC"]),
    # "N/A" must satisfy neither a numeric ">" nor a numeric "=".
    ("find records from ds1 where reading > 1000", []),
    ("find records from ds1 where reading = 0", []),
])
def test_metadata_filtering(basic_storage, query, expected_ids):
    """Each metadata filter query returns exactly the expected set of record ids."""
    outcome = NQLAgent(basic_storage).process_query(query)
    assert outcome["success"], f"Query '{query}' failed: {outcome.get('message', 'No message')}"
    assert outcome["count"] == len(expected_ids)

    # Compare ids irrespective of the order in which records come back.
    matched_ids = [row["metadata"]["id"] for row in outcome["results"]]
    matched_ids.sort()
    assert matched_ids == sorted(expected_ids)
def test_filter_meta_key_not_exist(basic_storage):
    """Filtering on a metadata key no record has succeeds with zero matches."""
    outcome = NQLAgent(basic_storage).process_query(
        "find records from ds1 where non_existent_key = 'value'"
    )
    assert outcome["success"]
    assert outcome["count"] == 0
def test_filter_meta_partially_existing_key(basic_storage):
    """Filtering on a key that only some records carry still succeeds.

    Record "ABC" has no ``is_calibrated`` entry, so only "007" should match.
    """
    outcome = NQLAgent(basic_storage).process_query(
        "find records from ds1 where is_calibrated = 'True'"
    )
    assert outcome["success"]
    assert outcome["count"] == 1
    only_match = outcome["results"][0]
    assert only_match["metadata"]["id"] == "007"
# --- Tensor Filter Tests ---
@pytest.mark.parametrize(("query", "expected_ids_values"), [
    # ds2 contents: rec1 -> [0.1, 0.2], rec2 -> [0.3, 0.4, 0.5], rec3 -> [].
    # Each expectation maps a matching record_id to the element compared.
    ("get records from ds2 where tensor[0] > 0.05", {"rec1": 0.1, "rec2": 0.3}),
    ("get records from ds2 where tensor[1] < 0.3", {"rec1": 0.2}),
    ("get records from ds2 where tensor[2] = 0.5", {"rec2": 0.5}),
    ("get records from ds2 where tensor[0] = 0.1", {"rec1": 0.1}),
    # Without an index the first element is expected to be compared.
    ("get records from ds2 where tensor > 0.2", {"rec2": 0.3}),  # rec1 (0.1) fails, rec2 (0.3) passes
    ("get records from ds2 where tensor = 0.1", {"rec1": 0.1}),
    # rec3's empty tensor has no element to read, so it must not match
    # (and must not make the query fail).
    ("get records from ds2 where tensor[0] > 0", {"rec1": 0.1, "rec2": 0.3}),
    ("get records from ds2 where tensor > 0", {"rec1": 0.1, "rec2": 0.3}),
])
def test_tensor_filtering(basic_storage, query, expected_ids_values):
    """Each tensor-value filter query returns exactly the expected record ids.

    Only the ids (the keys of ``expected_ids_values``) are verified; the mapped
    tensor values are informational and not checked against the results.
    """
    outcome = NQLAgent(basic_storage).process_query(query)
    assert outcome["success"], f"Query '{query}' failed: {outcome.get('message', 'No message')}"
    assert outcome["count"] == len(expected_ids_values)

    matched_ids = [row["metadata"]["record_id"] for row in outcome["results"]]
    matched_ids.sort()
    assert matched_ids == sorted(expected_ids_values)
def test_filter_tensor_index_out_of_bounds(basic_storage):
    """An index past the end of every ds2 tensor succeeds with zero matches."""
    outcome = NQLAgent(basic_storage).process_query(
        "get records from ds2 where tensor[5] > 0"
    )
    assert outcome["success"]
    assert outcome["count"] == 0
def test_filter_tensor_non_numeric_value_in_query(basic_storage):
    """Comparing a tensor element against a non-numeric value is rejected."""
    outcome = NQLAgent(basic_storage).process_query(
        "get records from ds2 where tensor[0] > 'abc'"
    )
    assert not outcome["success"]
    expected_fragment = "Tensor value filtering currently only supports numeric comparisons"
    assert expected_fragment in outcome["message"]
# --- General Error and Edge Case Tests ---
def test_filter_meta_query_missing_dataset(basic_storage):
    """A metadata filter on an unknown dataset fails and names that dataset."""
    outcome = NQLAgent(basic_storage).process_query(
        "find records from missing_ds where key = 'val'"
    )
    assert not outcome["success"]
    assert "missing_ds" in outcome["message"].lower()
def test_filter_tensor_query_missing_dataset(basic_storage):
    """A tensor filter on an unknown dataset fails and names that dataset."""
    outcome = NQLAgent(basic_storage).process_query(
        "show data from missing_ds where tensor[0] > 0"
    )
    assert not outcome["success"]
    assert "missing_ds" in outcome["message"].lower()
def test_unsupported_query_format(basic_storage):
    """Text that matches no supported query shape yields the fallback message."""
    outcome = NQLAgent(basic_storage).process_query("this is not a valid query")
    assert not outcome["success"]
    assert "Sorry, I couldn't understand that query" in outcome["message"]
def test_query_with_extra_spacing(basic_storage):
    """Queries padded with redundant spaces between tokens are still understood.

    Both query strings deliberately contain runs of several spaces. With
    single-spaced queries this test would merely repeat the plain filter and
    count tests and never exercise whitespace tolerance. No leading or
    trailing padding is used, so only inter-token spacing is under test.
    """
    agent = NQLAgent(basic_storage)

    # Metadata filter query with extra spaces between every token.
    spaced_filter = "find   records  from    ds1   where   id   =   '007'"
    result = agent.process_query(spaced_filter)
    assert result["success"], (
        f"Query {spaced_filter!r} failed: {result.get('message', 'No message')}"
    )
    assert result["count"] == 1
    assert result["results"][0]["metadata"]["id"] == "007"

    # Count query with extra spaces between every token.
    spaced_count = "count    records   in    ds1"
    result = agent.process_query(spaced_count)
    assert result["success"], (
        f"Query {spaced_count!r} failed: {result.get('message', 'No message')}"
    )
    assert result["count"] == 4