Spaces:
Running
Running
| import pytest | |
| pytest.importorskip("torch") | |
| from unittest.mock import patch | |
| import torch | |
| from tensorus.nql_agent import NQLAgent | |
| from tensorus.tensor_storage import TensorStorage | |
@pytest.fixture
def basic_storage():
    """Build an in-memory TensorStorage with two small datasets for NQL tests.

    ``ds1`` holds four records whose metadata exercises edge cases:
      * "007" / "008": complete metadata, ``is_calibrated`` True / False.
      * "ABC": no ``is_calibrated`` key.
      * "009": ``reading`` stored as the string "N/A" and no ``level`` key.

    ``ds2`` holds three records keyed by ``record_id``. Their tensors have
    lengths 2, 3 and 0 (rec3 is an empty tensor), which exercises indexed
    tensor filters.

    Returns:
        The populated ``TensorStorage`` instance.
    """
    storage = TensorStorage(storage_path=None)  # In-memory for tests

    storage.create_dataset("ds1")
    storage.insert("ds1", torch.tensor([1.0]), metadata={"id": "007", "name": "Sensor Alpha", "status": "active", "reading": 100, "level": 5.5, "is_calibrated": True})
    storage.insert("ds1", torch.tensor([2.0]), metadata={"id": "008", "name": "Sensor Beta", "status": "inactive", "reading": 200, "level": 5.5, "is_calibrated": False})
    storage.insert("ds1", torch.tensor([3.0]), metadata={"id": "ABC", "name": "Sensor Gamma", "status": "active", "reading": 150, "level": 6.0})  # Missing is_calibrated
    storage.insert("ds1", torch.tensor([4.0]), metadata={"id": "009", "name": "Sensor Delta", "status": "active", "reading": "N/A"})  # Reading as string

    storage.create_dataset("ds2")
    storage.insert("ds2", torch.tensor([0.1, 0.2]), metadata={"record_id": "rec1", "value": 10})
    storage.insert("ds2", torch.tensor([0.3, 0.4, 0.5]), metadata={"record_id": "rec2", "value": 20})
    storage.insert("ds2", torch.tensor([]), metadata={"record_id": "rec3", "value": 0})  # Empty tensor
    return storage
def test_count_query(basic_storage):
    """Count queries report the record total for both 'in' and 'from' phrasing."""
    nql = NQLAgent(basic_storage)
    for phrase, expected_total in (
        ("count records in ds1", 4),
        ("count records from ds2", 3),
    ):
        outcome = nql.process_query(phrase)
        assert outcome["success"]
        assert outcome["count"] == expected_total
def test_count_query_missing_dataset(basic_storage):
    """Counting an unknown dataset fails with a not-found message and a None count."""
    outcome = NQLAgent(basic_storage).process_query("count records in missing_ds")

    assert not outcome["success"]
    assert outcome["message"] == "Dataset 'missing_ds' not found."
    assert outcome["count"] is None
def test_get_all_query(basic_storage):
    """'get all' / 'show all' queries return every record of the named dataset.

    For ds1 the first returned record is also checked to be the first one
    inserted by the fixture ("007").
    """
    nql = NQLAgent(basic_storage)

    ds1_outcome = nql.process_query("get all data from ds1")
    assert ds1_outcome["success"]
    assert ds1_outcome["count"] == 4
    ds1_rows = ds1_outcome["results"]
    assert len(ds1_rows) == 4
    assert ds1_rows[0]["metadata"]["id"] == "007"

    ds2_outcome = nql.process_query("show all tensors from ds2")
    assert ds2_outcome["success"]
    assert ds2_outcome["count"] == 3
    assert len(ds2_outcome["results"]) == 3
def test_get_all_query_missing_dataset(basic_storage):
    """A 'get all' query on an unknown dataset fails and names the dataset.

    The exact wording is deliberately not pinned. Unlike the count path, the
    get-all path has no dedicated not-found handling, so the message may come
    from the storage layer. Only the presence of the dataset name is checked,
    case-insensitively.
    """
    outcome = NQLAgent(basic_storage).process_query("get all records from missing_ds")
    assert not outcome["success"]
    message_text = outcome["message"].lower()
    assert "missing_ds" in message_text
# --- Metadata Filter Tests ---
@pytest.mark.parametrize(
    ("query", "expected_ids"),
    [
        # Exact string match on a unique key.
        ("find records from ds1 where id = '007'", ["007"]),
        # String match that hits several records.
        ("find records from ds1 where status = 'active'", ["007", "ABC", "009"]),
        ("find records from ds1 where status = 'inactive'", ["008"]),
        # Key missing on one record ("ABC"); that record must simply not match.
        ("find records from ds1 where is_calibrated = 'True'", ["007"]),
    ],
)
def test_metadata_filtering(basic_storage, query, expected_ids):
    """Metadata 'where' filters return exactly the records whose metadata matches.

    Args:
        basic_storage: The shared fixture storage.
        query: The NQL query string under test.
        expected_ids: The ``id`` metadata values expected in the result,
            in any order.
    """
    agent = NQLAgent(basic_storage)
    result = agent.process_query(query)
    assert result["success"], f"Query '{query}' failed: {result.get('message', 'No message')}"
    assert result["count"] == len(expected_ids)
    # Results are compared order-insensitively; ordering is not part of the contract tested here.
    found_ids = sorted(res["metadata"]["id"] for res in result["results"])
    assert found_ids == sorted(expected_ids)
def test_filter_meta_key_not_exist(basic_storage):
    """Filtering on a metadata key that no record has succeeds with zero matches."""
    outcome = NQLAgent(basic_storage).process_query(
        "find records from ds1 where non_existent_key = 'value'"
    )
    assert outcome["success"]
    assert outcome["count"] == 0
def test_filter_meta_partially_existing_key(basic_storage):
    """Records lacking the filtered key are skipped rather than breaking the query.

    In the fixture, "ABC" has no 'is_calibrated' entry and "008" has it set to
    False. Only "007" is expected to match the quoted 'True' literal.
    """
    outcome = NQLAgent(basic_storage).process_query(
        "find records from ds1 where is_calibrated = 'True'"
    )
    assert outcome["success"]
    assert outcome["count"] == 1
    first_match = outcome["results"][0]
    assert first_match["metadata"]["id"] == "007"
| # --- Tensor Filter Tests --- | |
@pytest.mark.parametrize(
    ("query", "expected_ids_values"),
    [
        # rec3 is an empty tensor, so any indexed condition excludes it.
        ("get records from ds2 where tensor[0] > 0", {"rec1": 0.1, "rec2": 0.3}),
        # Only rec2 has a third element.
        ("get records from ds2 where tensor[2] > 0", {"rec2": 0.5}),
        ("get records from ds2 where tensor[1] < 1", {"rec1": 0.2, "rec2": 0.4}),
    ],
)
def test_tensor_filtering(basic_storage, query, expected_ids_values):
    """Indexed tensor filters return exactly the records whose element satisfies the condition.

    Args:
        basic_storage: The shared fixture storage.
        query: The NQL query string under test.
        expected_ids_values: A mapping from expected ``record_id`` to the
            tensor element the condition inspects. Only the keys are asserted;
            the values document which element made each record match.
    """
    agent = NQLAgent(basic_storage)
    result = agent.process_query(query)
    assert result["success"], f"Query '{query}' failed: {result.get('message', 'No message')}"
    assert result["count"] == len(expected_ids_values)
    found_ids = sorted(res["metadata"]["record_id"] for res in result["results"])
    assert found_ids == sorted(expected_ids_values.keys())
    # Further checks on actual tensor values could be added if needed
def test_filter_tensor_index_out_of_bounds(basic_storage):
    """An index beyond every tensor's length yields zero matches, not an error."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("get records from ds2 where tensor[5] > 0")
    assert outcome["success"]
    assert outcome["count"] == 0
def test_filter_tensor_non_numeric_value_in_query(basic_storage):
    """Comparing a tensor element against a quoted string is rejected with an explanation."""
    expected_fragment = "Tensor value filtering currently only supports numeric comparisons"
    outcome = NQLAgent(basic_storage).process_query(
        "get records from ds2 where tensor[0] > 'abc'"
    )
    assert not outcome["success"]
    assert expected_fragment in outcome["message"]
| # --- General Error and Edge Case Tests --- | |
def test_filter_meta_query_missing_dataset(basic_storage):
    """A metadata filter against an unknown dataset fails and names that dataset."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("find records from missing_ds where key = 'val'")
    assert not outcome["success"]
    message_text = outcome["message"].lower()
    assert "missing_ds" in message_text
def test_filter_tensor_query_missing_dataset(basic_storage):
    """A tensor filter against an unknown dataset fails and names that dataset."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("show data from missing_ds where tensor[0] > 0")
    assert not outcome["success"]
    message_text = outcome["message"].lower()
    assert "missing_ds" in message_text
def test_unsupported_query_format(basic_storage):
    """Free text that matches no query pattern produces the 'couldn't understand' reply."""
    outcome = NQLAgent(basic_storage).process_query("this is not a valid query")
    assert not outcome["success"]
    assert "Sorry, I couldn't understand that query" in outcome["message"]
def test_query_with_extra_spacing(basic_storage):
    """Queries with repeated whitespace between tokens parse like normally spaced ones.

    The previous version used single-spaced queries and so did not exercise
    the behavior its name describes. Both a metadata-filter query and a count
    query below contain runs of multiple spaces between their tokens.
    """
    agent = NQLAgent(basic_storage)

    # Metadata filter query with extra spaces between keywords and around '='.
    query = "find   records  from   ds1   where  id  =  '007'"
    result = agent.process_query(query)
    assert result["success"], f"Query '{query}' failed: {result.get('message', 'No message')}"
    assert result["count"] == 1
    assert result["results"][0]["metadata"]["id"] == "007"

    # Count query with extra spaces between keywords.
    query = "count   records   in   ds1"
    result = agent.process_query(query)
    assert result["success"], f"Query '{query}' failed: {result.get('message', 'No message')}"
    assert result["count"] == 4