import os import importlib.util import pytest from unittest.mock import patch pytest.importorskip("torch") pytest.importorskip("transformers") import torch from fastapi.testclient import TestClient spec = importlib.util.spec_from_file_location( "tensorus.api_legacy", os.path.join(os.path.dirname(__file__), "..", "tensorus", "api.py"), ) api = importlib.util.module_from_spec(spec) spec.loader.exec_module(api) from tensorus.nql_agent import NQLAgent from tensorus.tensor_storage import TensorStorage @pytest.fixture def client_with_llm(monkeypatch): with patch("transformers.pipeline") as mock_root, patch("transformers.pipelines.pipeline") as mock_pipe: mock_root.return_value = mock_pipe.return_value = lambda q: [{"generated_text": "get all data from test_ds"}] storage = TensorStorage(storage_path=None) storage.create_dataset("test_ds") storage.insert("test_ds", torch.tensor([1.0]), metadata={"record_id": "r1"}) agent = NQLAgent(storage, use_llm=True) monkeypatch.setattr(api, "nql_agent_instance", agent) if hasattr(api.get_nql_agent, "_instance"): monkeypatch.setattr(api.get_nql_agent, "_instance", agent, raising=False) with TestClient(api.app) as client: yield client def test_query_endpoint_with_llm_rewrite(client_with_llm): resp = client_with_llm.post("/query", json={"query": "nonsense"}) assert resp.status_code == 200 data = resp.json() assert data["success"] is True assert data["count"] == 1