| import pytest |
| pytest.importorskip("torch") |
| from unittest.mock import patch |
| import torch |
|
|
| from tensorus.nql_agent import NQLAgent |
| from tensorus.tensor_storage import TensorStorage |
|
|
@pytest.fixture
def basic_storage():
    """Return a TensorStorage pre-populated with two small datasets.

    ``ds1`` holds four one-element tensors whose metadata mixes string, int,
    float and bool values. Not every record carries every key: ``"ABC"`` has
    no ``is_calibrated``, and ``"009"`` has neither ``level`` nor
    ``is_calibrated`` and stores the string ``"N/A"`` under ``reading``.

    ``ds2`` holds three tensors of lengths 2, 3 and 0, each tagged with a
    ``record_id`` and an integer ``value``.
    """
    # NOTE(review): storage_path=None presumably keeps the storage in memory.
    store = TensorStorage(storage_path=None)

    sensors = (
        ([1.0], {"id": "007", "name": "Sensor Alpha", "status": "active",
                 "reading": 100, "level": 5.5, "is_calibrated": True}),
        ([2.0], {"id": "008", "name": "Sensor Beta", "status": "inactive",
                 "reading": 200, "level": 5.5, "is_calibrated": False}),
        ([3.0], {"id": "ABC", "name": "Sensor Gamma", "status": "active",
                 "reading": 150, "level": 6.0}),
        ([4.0], {"id": "009", "name": "Sensor Delta", "status": "active",
                 "reading": "N/A"}),
    )
    records = (
        ([0.1, 0.2], {"record_id": "rec1", "value": 10}),
        ([0.3, 0.4, 0.5], {"record_id": "rec2", "value": 20}),
        # Deliberately empty tensor: index/value filters must skip it.
        ([], {"record_id": "rec3", "value": 0}),
    )

    # Each dataset is created immediately before its own rows are inserted.
    for dataset, rows in (("ds1", sensors), ("ds2", records)):
        store.create_dataset(dataset)
        for values, meta in rows:
            store.insert(dataset, torch.tensor(values), metadata=meta)

    return store
|
|
|
|
def test_count_query(basic_storage):
    """Counting works with both the ``in`` and ``from`` phrasings."""
    nql = NQLAgent(basic_storage)
    for text, expected in (
        ("count records in ds1", 4),
        ("count records from ds2", 3),
    ):
        outcome = nql.process_query(text)
        assert outcome["success"]
        assert outcome["count"] == expected
|
|
|
|
def test_count_query_missing_dataset(basic_storage):
    """Counting an unknown dataset fails with an exact message and no count."""
    outcome = NQLAgent(basic_storage).process_query("count records in missing_ds")

    assert not outcome["success"]
    assert outcome["message"] == "Dataset 'missing_ds' not found."
    assert outcome["count"] is None
|
|
|
|
def test_get_all_query(basic_storage):
    """"get all"/"show all" return every record of a dataset, in insertion order."""
    nql = NQLAgent(basic_storage)

    listing = nql.process_query("get all data from ds1")
    assert listing["success"]
    assert listing["count"] == 4
    rows = listing["results"]
    assert len(rows) == 4
    # The first row returned is the first one the fixture inserted.
    assert rows[0]["metadata"]["id"] == "007"

    listing = nql.process_query("show all tensors from ds2")
    assert listing["success"]
    assert listing["count"] == 3
    assert len(listing["results"]) == 3
|
|
|
|
def test_get_all_query_missing_dataset(basic_storage):
    """Listing an unknown dataset fails and names that dataset in the message."""
    outcome = NQLAgent(basic_storage).process_query("get all records from missing_ds")
    assert not outcome["success"]

    # Only the dataset name is pinned; the rest of the wording may vary.
    lowered = outcome["message"].lower()
    assert "missing_ds" in lowered
|
|
|
|
| |
|
|
_METADATA_FILTER_CASES = [
    # String ids: quoted literals match; an unquoted 007 does not equal "007".
    ("find records from ds1 where id = '007'", ["007"]),
    ("find records from ds1 where id == \"007\"", ["007"]),
    ("find records from ds1 where id = 'ABC'", ["ABC"]),
    ("find records from ds1 where id = 007", []),
    ("find records from ds1 where id != '007'", ["008", "ABC", "009"]),
    ("find records from ds1 where id != 007", ["007", "008", "ABC", "009"]),
    # Numeric fields, compared with both numeric and quoted literals.
    ("find records from ds1 where reading = 100", ["007"]),
    ("find records from ds1 where reading == 150", ["ABC"]),
    ("find records from ds1 where reading = '100'", ["007"]),
    ("find records from ds1 where reading > 100", ["008", "ABC"]),
    ("find records from ds1 where reading >= 150", ["008", "ABC"]),
    ("find records from ds1 where reading < 150", ["007"]),
    ("find records from ds1 where reading <= 100", ["007"]),
    ("find records from ds1 where reading != 100", ["008", "ABC", "009"]),
    ("find records from ds1 where level = 5.5", ["007", "008"]),
    ("find records from ds1 where level > 5.5", ["ABC"]),
    ("find records from ds1 where level = '5.5'", ["007", "008"]),
    # Word-form equality operators.
    ("find records from ds1 where status = 'active'", ["007", "ABC", "009"]),
    ("find records from ds1 where status is 'inactive'", ["008"]),
    ("find records from ds1 where status equals \"active\"", ["007", "ABC", "009"]),
    ("find records from ds1 where status eq 'active'", ["007", "ABC", "009"]),
    # Booleans written as quoted strings.
    ("find records from ds1 where is_calibrated = 'True'", ["007"]),
    ("find records from ds1 where is_calibrated = 'False'", ["008"]),
    # The "N/A" reading and out-of-range thresholds.
    ("find records from ds1 where reading = 'N/A'", ["009"]),
    ("find records from ds1 where reading != 'N/A'", ["007", "008", "ABC"]),
    ("find records from ds1 where reading > 1000", []),
    ("find records from ds1 where reading = 0", []),
]


@pytest.mark.parametrize("query, expected_ids", _METADATA_FILTER_CASES)
def test_metadata_filtering(basic_storage, query, expected_ids):
    """A metadata ``where`` clause returns exactly the expected ``id`` values.

    The comparison is order-insensitive.
    """
    outcome = NQLAgent(basic_storage).process_query(query)
    assert outcome["success"], f"Query '{query}' failed: {outcome.get('message', 'No message')}"
    assert outcome["count"] == len(expected_ids)

    matched = sorted(hit["metadata"]["id"] for hit in outcome["results"])
    assert matched == sorted(expected_ids)
|
|
def test_filter_meta_key_not_exist(basic_storage):
    """Filtering on a key that no record carries succeeds with zero matches."""
    outcome = NQLAgent(basic_storage).process_query(
        "find records from ds1 where non_existent_key = 'value'"
    )
    assert outcome["success"]
    assert outcome["count"] == 0
|
|
def test_filter_meta_partially_existing_key(basic_storage):
    """A key present on only some records filters over those records alone.

    In the fixture, ``is_calibrated`` exists on "007" and "008" only.
    """
    outcome = NQLAgent(basic_storage).process_query(
        "find records from ds1 where is_calibrated = 'True'"
    )
    assert outcome["success"]
    assert outcome["count"] == 1
    (only_hit, *_) = outcome["results"]
    assert only_hit["metadata"]["id"] == "007"
|
|
| |
|
|
_TENSOR_FILTER_CASES = [
    # Indexed element comparisons.
    ("get records from ds2 where tensor[0] > 0.05", {"rec1": 0.1, "rec2": 0.3}),
    ("get records from ds2 where tensor[1] < 0.3", {"rec1": 0.2}),
    ("get records from ds2 where tensor[2] = 0.5", {"rec2": 0.5}),
    ("get records from ds2 where tensor[0] = 0.1", {"rec1": 0.1}),
    # Whole-tensor comparisons (no index).
    ("get records from ds2 where tensor > 0.2", {"rec2": 0.3}),
    ("get records from ds2 where tensor = 0.1", {"rec1": 0.1}),
    # The empty tensor (rec3) never matches.
    ("get records from ds2 where tensor[0] > 0", {"rec1": 0.1, "rec2": 0.3}),
    ("get records from ds2 where tensor > 0", {"rec1": 0.1, "rec2": 0.3}),
]


@pytest.mark.parametrize("query, expected_ids_values", _TENSOR_FILTER_CASES)
def test_tensor_filtering(basic_storage, query, expected_ids_values):
    """A tensor-value ``where`` clause returns exactly the expected records.

    Only the ``record_id`` keys of ``expected_ids_values`` are asserted. The
    mapped numbers are descriptive and are not checked against the results.
    """
    outcome = NQLAgent(basic_storage).process_query(query)
    assert outcome["success"], f"Query '{query}' failed: {outcome.get('message', 'No message')}"
    assert outcome["count"] == len(expected_ids_values)

    matched = sorted(hit["metadata"]["record_id"] for hit in outcome["results"])
    assert matched == sorted(expected_ids_values)
| |
|
|
def test_filter_tensor_index_out_of_bounds(basic_storage):
    """An index past every tensor's length matches nothing instead of failing."""
    outcome = NQLAgent(basic_storage).process_query(
        "get records from ds2 where tensor[5] > 0"
    )
    assert outcome["success"]
    assert outcome["count"] == 0
|
|
def test_filter_tensor_non_numeric_value_in_query(basic_storage):
    """Comparing a tensor element to a string is rejected with an explanation."""
    outcome = NQLAgent(basic_storage).process_query(
        "get records from ds2 where tensor[0] > 'abc'"
    )
    assert not outcome["success"]
    expected_fragment = "Tensor value filtering currently only supports numeric comparisons"
    assert expected_fragment in outcome["message"]
|
|
| |
def test_filter_meta_query_missing_dataset(basic_storage):
    """A metadata filter on an unknown dataset fails and names that dataset."""
    outcome = NQLAgent(basic_storage).process_query(
        "find records from missing_ds where key = 'val'"
    )
    assert not outcome["success"]
    lowered = outcome["message"].lower()
    assert "missing_ds" in lowered
|
|
def test_filter_tensor_query_missing_dataset(basic_storage):
    """A tensor filter on an unknown dataset fails and names that dataset."""
    outcome = NQLAgent(basic_storage).process_query(
        "show data from missing_ds where tensor[0] > 0"
    )
    assert not outcome["success"]
    lowered = outcome["message"].lower()
    assert "missing_ds" in lowered
|
|
def test_unsupported_query_format(basic_storage):
    """Unparseable input yields a failure with the agent's fallback message."""
    outcome = NQLAgent(basic_storage).process_query("this is not a valid query")
    assert not outcome["success"]
    fallback = "Sorry, I couldn't understand that query"
    assert fallback in outcome["message"]
|
|
def test_query_with_extra_spacing(basic_storage):
    """Queries padded with redundant internal whitespace parse normally.

    Each query below contains runs of spaces between its tokens. It must give
    the same result as its single-spaced form, which is covered elsewhere in
    this module.
    """
    # NOTE(review): assumes the NQL parser splits tokens on runs of
    # whitespace rather than on single spaces. That is what this test exists
    # to pin down.
    agent = NQLAgent(basic_storage)

    # Filter query: extra spaces between keywords and around the operator.
    result = agent.process_query("find   records  from   ds1   where  id  =  '007'")
    assert result["success"], result.get("message")
    assert result["count"] == 1
    assert result["results"][0]["metadata"]["id"] == "007"

    # Count query: extra spaces between every keyword.
    result = agent.process_query("count    records   in    ds1")
    assert result["success"], result.get("message")
    assert result["count"] == 4
|
|