Spaces:
Running
Running
| import os | |
| import importlib.util | |
| import pytest | |
| from unittest.mock import patch | |
| pytest.importorskip("torch") | |
| pytest.importorskip("transformers") | |
| import torch | |
| from fastapi.testclient import TestClient | |
import sys  # local import: only needed to register the legacy module below

# Load tensorus/api.py straight from disk under the module name
# "tensorus.api_legacy", which is distinct from the package's own
# "tensorus.api" name.
_LEGACY_API_NAME = "tensorus.api_legacy"
_LEGACY_API_PATH = os.path.join(os.path.dirname(__file__), "..", "tensorus", "api.py")

spec = importlib.util.spec_from_file_location(_LEGACY_API_NAME, _LEGACY_API_PATH)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when it cannot pick a loader;
    # fail with a readable message instead of an AttributeError below.
    raise ImportError(f"cannot build an import spec for {_LEGACY_API_PATH}")
api = importlib.util.module_from_spec(spec)
# Register before executing (per the importlib "import a source file directly"
# recipe) so code in the module that looks itself up in sys.modules by its
# __module__ name (type-hint resolution, pickling) can find it.
sys.modules[_LEGACY_API_NAME] = api
try:
    spec.loader.exec_module(api)
except BaseException:
    # Mirror the import system: don't leave a half-initialised module behind.
    sys.modules.pop(_LEGACY_API_NAME, None)
    raise
| from tensorus.nql_agent import NQLAgent | |
| from tensorus.tensor_storage import TensorStorage | |
@pytest.fixture
def client_with_llm(monkeypatch):
    """Yield a ``TestClient`` for the legacy API whose NQL agent runs with ``use_llm=True``.

    Both ``transformers.pipeline`` and ``transformers.pipelines.pipeline`` are
    patched to return a fake generator that always produces the text
    ``"get all data from test_ds"``. The agent is backed by a fresh
    ``TensorStorage(storage_path=None)`` containing one tensor in dataset
    ``test_ds`` with ``record_id`` ``"r1"``. The agent is installed on the API
    module as ``nql_agent_instance`` and, when that attribute exists, as
    ``get_nql_agent._instance``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; undoes the attribute
            replacements after the test.

    Yields:
        fastapi.testclient.TestClient: client bound to ``api.app``.
    """

    def fake_generate(q):
        # Canned LLM output: every input is "rewritten" into this NQL query.
        return [{"generated_text": "get all data from test_ds"}]

    # Keep both patches active for the entire fixture lifetime, not just agent
    # construction, so a pipeline created later (app startup, first query)
    # also gets the fake instead of the real transformers package.
    with patch("transformers.pipeline") as mock_root, \
            patch("transformers.pipelines.pipeline") as mock_pipe:
        mock_root.return_value = mock_pipe.return_value = fake_generate

        storage = TensorStorage(storage_path=None)
        storage.create_dataset("test_ds")
        storage.insert("test_ds", torch.tensor([1.0]), metadata={"record_id": "r1"})
        agent = NQLAgent(storage, use_llm=True)

        monkeypatch.setattr(api, "nql_agent_instance", agent)
        # hasattr already guarantees the attribute exists, so the default
        # raising=True cannot trigger here.
        if hasattr(api.get_nql_agent, "_instance"):
            monkeypatch.setattr(api.get_nql_agent, "_instance", agent)

        with TestClient(api.app) as client:
            yield client
def test_query_endpoint_with_llm_rewrite(client_with_llm):
    """POST a nonsensical query to /query and expect a successful single-result response.

    The ``client_with_llm`` fixture's fake LLM rewrites any input into
    ``get all data from test_ds``, a dataset holding exactly one record.
    """
    response = client_with_llm.post("/query", json={"query": "nonsense"})
    assert response.status_code == 200, response.text

    body = response.json()
    assert body["success"] is True
    assert body["count"] == 1