import pytest
pytest.importorskip("torch")
from unittest.mock import patch
import torch
from tensorus.nql_agent import NQLAgent
from tensorus.tensor_storage import TensorStorage
@pytest.fixture
def basic_storage():
    """In-memory TensorStorage pre-populated with two datasets for NQL tests."""
    store = TensorStorage(storage_path=None)  # storage_path=None keeps everything in memory

    store.create_dataset("ds1")
    ds1_rows = [
        (torch.tensor([1.0]), {"id": "007", "name": "Sensor Alpha", "status": "active", "reading": 100, "level": 5.5, "is_calibrated": True}),
        (torch.tensor([2.0]), {"id": "008", "name": "Sensor Beta", "status": "inactive", "reading": 200, "level": 5.5, "is_calibrated": False}),
        # "ABC" deliberately lacks is_calibrated; "009" has a non-numeric reading.
        (torch.tensor([3.0]), {"id": "ABC", "name": "Sensor Gamma", "status": "active", "reading": 150, "level": 6.0}),
        (torch.tensor([4.0]), {"id": "009", "name": "Sensor Delta", "status": "active", "reading": "N/A"}),
    ]
    for tensor, meta in ds1_rows:
        store.insert("ds1", tensor, metadata=meta)

    store.create_dataset("ds2")
    ds2_rows = [
        (torch.tensor([0.1, 0.2]), {"record_id": "rec1", "value": 10}),
        (torch.tensor([0.3, 0.4, 0.5]), {"record_id": "rec2", "value": 20}),
        (torch.tensor([]), {"record_id": "rec3", "value": 0}),  # empty tensor edge case
    ]
    for tensor, meta in ds2_rows:
        store.insert("ds2", tensor, metadata=meta)

    return store
def test_count_query(basic_storage):
    """Counting via both the 'in' and 'from' phrasings returns dataset sizes."""
    nql = NQLAgent(basic_storage)

    # ds1 holds four records.
    response = nql.process_query("count records in ds1")
    assert response["success"]
    assert response["count"] == 4

    # ds2 holds three; the 'from' keyword is accepted as well.
    response = nql.process_query("count records from ds2")
    assert response["success"]
    assert response["count"] == 3
def test_count_query_missing_dataset(basic_storage):
    """Counting an unknown dataset fails with a specific message and a null count."""
    nql = NQLAgent(basic_storage)
    response = nql.process_query("count records in missing_ds")
    assert not response["success"]
    assert response["message"] == "Dataset 'missing_ds' not found."
    assert response["count"] is None
def test_get_all_query(basic_storage):
    """'get all' / 'show all' phrasings return every record of a dataset."""
    nql = NQLAgent(basic_storage)

    response = nql.process_query("get all data from ds1")
    assert response["success"]
    assert response["count"] == 4
    assert len(response["results"]) == 4
    # Insertion order is preserved: the first record is id "007".
    assert response["results"][0]["metadata"]["id"] == "007"

    response = nql.process_query("show all tensors from ds2")
    assert response["success"]
    assert response["count"] == 3
    assert len(response["results"]) == 3
def test_get_all_query_missing_dataset(basic_storage):
    """A 'get all' against an unknown dataset fails and names the dataset."""
    nql = NQLAgent(basic_storage)
    response = nql.process_query("get all records from missing_ds")
    assert not response["success"]
    # The error may originate in TensorStorage, so the exact wording can vary:
    # unlike the count path, the get-all path has no dedicated
    # DatasetNotFoundError handling yet, so we only require that the
    # missing dataset's name appears somewhere in the message.
    assert "missing_ds" in response["message"].lower()
# --- Metadata Filter Tests (Focus of T1 from planning) ---
@pytest.mark.parametrize("query, expected_ids", [
    # String ID comparisons (id is string "007", "008", "ABC", "009")
    ("find records from ds1 where id = '007'", ["007"]),
    ("find records from ds1 where id == \"007\"", ["007"]), # Double quotes
    ("find records from ds1 where id = 'ABC'", ["ABC"]),
    ("find records from ds1 where id = 007", []), # Unquoted "007" becomes num 7, should not match string "007"
    ("find records from ds1 where id != '007'", ["008", "ABC", "009"]),
    ("find records from ds1 where id != 007", ["007", "008", "ABC", "009"]), # num 7 != string "007" (coercion to string then compare)
    # Numeric value comparisons (reading is number 100, 200, 150; level is float 5.5, 6.0)
    ("find records from ds1 where reading = 100", ["007"]),
    ("find records from ds1 where reading == 150", ["ABC"]),
    ("find records from ds1 where reading = '100'", ["007"]), # Query "100" (str) vs metadata 100 (int) -> coercion
    ("find records from ds1 where reading > 100", ["008", "ABC"]),
    ("find records from ds1 where reading >= 150", ["008", "ABC"]),
    ("find records from ds1 where reading < 150", ["007"]),
    ("find records from ds1 where reading <= 100", ["007"]),
    ("find records from ds1 where reading != 100", ["008", "ABC", "009"]), # "009" has reading "N/A"
    ("find records from ds1 where level = 5.5", ["007", "008"]),
    ("find records from ds1 where level > 5.5", ["ABC"]),
    ("find records from ds1 where level = '5.5'", ["007", "008"]), # Query "5.5" (str) vs metadata 5.5 (float)
    # String value comparisons (status is string "active", "inactive")
    ("find records from ds1 where status = 'active'", ["007", "ABC", "009"]),
    ("find records from ds1 where status is 'inactive'", ["008"]), # alt syntax
    ("find records from ds1 where status equals \"active\"", ["007", "ABC", "009"]), # alt syntax
    ("find records from ds1 where status eq 'active'", ["007", "ABC", "009"]), # alt syntax
    # Boolean value comparisons (is_calibrated is True, False, or missing)
    # Current _parse_operator_and_value will treat True/False as unquoted strings if not careful.
    # With the fix, unquoted True/False will be strings "True"/"False".
    # Coercion logic in query_fn_meta: if filter_value is str, actual_value becomes str.
    ("find records from ds1 where is_calibrated = 'True'", ["007"]), # actual True -> "True" vs "True"
    ("find records from ds1 where is_calibrated = 'False'", ["008"]),# actual False -> "False" vs "False"
    # ("find records from ds1 where is_calibrated = True", ["007"]), # This query would make filter_value string "True"
    # ("find records from ds1 where is_calibrated = 1", []), # is_calibrated (bool) vs 1 (int) - needs specific handling or relies on coercion
    # Handling of metadata key "reading" which is string "N/A" for one record
    ("find records from ds1 where reading = 'N/A'", ["009"]),
    ("find records from ds1 where reading != 'N/A'", ["007", "008", "ABC"]),
    ("find records from ds1 where reading > 1000", []), # "N/A" should not satisfy > numeric
    ("find records from ds1 where reading = 0", []), # "N/A" should not satisfy = numeric
])
def test_metadata_filtering(basic_storage, query, expected_ids):
    """Each parametrized WHERE clause must select exactly the expected ids."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query(query)
    assert outcome["success"], f"Query '{query}' failed: {outcome.get('message', 'No message')}"
    assert outcome["count"] == len(expected_ids)
    matched_ids = sorted(item["metadata"]["id"] for item in outcome["results"])
    assert matched_ids == sorted(expected_ids)
def test_filter_meta_key_not_exist(basic_storage):
    """Filtering on a key no record has succeeds but matches nothing."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("find records from ds1 where non_existent_key = 'value'")
    assert outcome["success"]
    assert outcome["count"] == 0
def test_filter_meta_partially_existing_key(basic_storage):
    """Records missing the filtered key are simply skipped, not errors."""
    nql = NQLAgent(basic_storage)
    # 'is_calibrated' is absent from record "ABC"; only "007" has True.
    outcome = nql.process_query("find records from ds1 where is_calibrated = 'True'")
    assert outcome["success"]
    assert outcome["count"] == 1
    assert outcome["results"][0]["metadata"]["id"] == "007"
# --- Tensor Filter Tests ---
@pytest.mark.parametrize("query, expected_ids_values", [
    # ds2: rec1: [0.1, 0.2], val:10; rec2: [0.3, 0.4, 0.5], val:20; rec3: [], val:0
    ("get records from ds2 where tensor[0] > 0.05", {"rec1": 0.1, "rec2": 0.3}),
    ("get records from ds2 where tensor[1] < 0.3", {"rec1": 0.2}),
    ("get records from ds2 where tensor[2] = 0.5", {"rec2": 0.5}),
    ("get records from ds2 where tensor[0] = 0.1", {"rec1": 0.1}),
    # Index-less queries (default to first element)
    ("get records from ds2 where tensor > 0.2", {"rec2": 0.3}), # rec1 fails (0.1), rec2 matches (0.3)
    ("get records from ds2 where tensor = 0.1", {"rec1": 0.1}),
    # Empty tensor record (rec3) should not match positive conditions, or error
    ("get records from ds2 where tensor[0] > 0", {"rec1": 0.1, "rec2": 0.3}), # rec3's empty tensor fails index access
    ("get records from ds2 where tensor > 0", {"rec1": 0.1, "rec2": 0.3}), # rec3's empty tensor fails first element access
])
def test_tensor_filtering(basic_storage, query, expected_ids_values):
    """Tensor-element WHERE clauses select exactly the expected record_ids."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query(query)
    assert outcome["success"], f"Query '{query}' failed: {outcome.get('message', 'No message')}"
    assert outcome["count"] == len(expected_ids_values)
    matched_ids = sorted(item["metadata"]["record_id"] for item in outcome["results"])
    assert matched_ids == sorted(expected_ids_values.keys())
    # Further checks on the actual tensor values could be added if needed.
def test_filter_tensor_index_out_of_bounds(basic_storage):
    """An index beyond every tensor's length yields zero matches, not an error."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("get records from ds2 where tensor[5] > 0")
    assert outcome["success"]
    assert outcome["count"] == 0
def test_filter_tensor_non_numeric_value_in_query(basic_storage):
    """Comparing a tensor element to a quoted string is rejected outright."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("get records from ds2 where tensor[0] > 'abc'")
    assert not outcome["success"]
    assert "Tensor value filtering currently only supports numeric comparisons" in outcome["message"]
# --- General Error and Edge Case Tests ---
def test_filter_meta_query_missing_dataset(basic_storage):
    """A metadata filter on an unknown dataset fails and names the dataset."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("find records from missing_ds where key = 'val'")
    assert not outcome["success"]
    assert "missing_ds" in outcome["message"].lower()
def test_filter_tensor_query_missing_dataset(basic_storage):
    """A tensor filter on an unknown dataset fails and names the dataset."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("show data from missing_ds where tensor[0] > 0")
    assert not outcome["success"]
    assert "missing_ds" in outcome["message"].lower()
def test_unsupported_query_format(basic_storage):
    """Unparseable input is rejected with the agent's apology message."""
    nql = NQLAgent(basic_storage)
    outcome = nql.process_query("this is not a valid query")
    assert not outcome["success"]
    assert "Sorry, I couldn't understand that query" in outcome["message"]
def test_query_with_extra_spacing(basic_storage):
    """Queries with irregular internal whitespace still parse and run.

    BUG FIX: the original test used single-spaced query strings, so despite its
    name it never exercised whitespace tolerance (it merely duplicated earlier
    tests). The queries below carry genuine extra spaces between tokens.
    """
    agent = NQLAgent(basic_storage)

    # Metadata filter query with extra internal spaces.
    query = "find  records   from  ds1   where  id  =  '007'"
    result = agent.process_query(query)
    assert result["success"], f"Query '{query}' failed: {result.get('message', 'No message')}"
    assert result["count"] == 1
    assert result["results"][0]["metadata"]["id"] == "007"

    # Count query with extra internal spaces.
    query = "count  records   in  ds1"
    result = agent.process_query(query)
    assert result["success"], f"Query '{query}' failed: {result.get('message', 'No message')}"
    assert result["count"] == 4
|