# File size: 7,249 bytes
# edfa748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import json
import pytest
import httpx

from tensorus import mcp_server

class DummyResponse:
    """Canned response object handed back by the mocked AsyncClient methods.

    ``json()`` returns whatever was passed to the constructor, and
    ``raise_for_status()`` never raises, so every response looks successful.
    """

    def __init__(self, data):
        # Stored as-is; json() hands back this exact object.
        self._payload = data

    def json(self):
        """Return the canned payload (same object, not a copy)."""
        return self._payload

    def raise_for_status(self):
        """Do nothing: this stub never signals an HTTP error."""
        return None


def make_mock_client(monkeypatch, method, url, payload, response):
    """Replace ``mcp_server.httpx.AsyncClient`` with a single-request stub.

    The stub asserts that the HTTP verb used equals *method* and the URL equals
    *url*. For ``post`` and ``put`` it also asserts that the ``json=`` body
    equals *payload*; *payload* is ignored for ``get`` and ``delete``. Every
    verb returns ``DummyResponse(response)``. The patch is undone by pytest's
    ``monkeypatch`` fixture at test teardown.
    """

    def _expect(verb, requested_url):
        # Shared checks for every verb: right method, right endpoint.
        assert method == verb
        assert requested_url == url

    class MockAsyncClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            pass

        # The keyword must stay named ``json`` to match how callers pass bodies.
        async def post(self, u, json=None):
            _expect("post", u)
            assert json == payload
            return DummyResponse(response)

        async def get(self, u):
            _expect("get", u)
            return DummyResponse(response)

        async def put(self, u, json=None):
            _expect("put", u)
            assert json == payload
            return DummyResponse(response)

        async def delete(self, u):
            _expect("delete", u)
            return DummyResponse(response)

    monkeypatch.setattr(mcp_server.httpx, "AsyncClient", MockAsyncClient)


def make_error_client(monkeypatch, method):
    """Replace ``mcp_server.httpx.AsyncClient`` with a stub whose requests fail.

    Calling the verb named by *method* raises ``httpx.HTTPError("failed")``.
    Calling any other verb trips an ``AssertionError`` instead, which flags a
    test that expected a different HTTP method.
    """

    def _fail(verb):
        # Guard against the code under test using an unexpected verb, then
        # simulate a transport/HTTP failure.
        assert method == verb
        raise httpx.HTTPError("failed")

    class ErrorAsyncClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            pass

        async def post(self, u, json=None):
            _fail("post")

        async def get(self, u):
            _fail("get")

        async def put(self, u, json=None):
            _fail("put")

        async def delete(self, u):
            _fail("delete")

    monkeypatch.setattr(mcp_server.httpx, "AsyncClient", ErrorAsyncClient)


@pytest.mark.asyncio
async def test_save_tensor(monkeypatch):
    """save_tensor POSTs the tensor to the ingest URL and returns the reply as JSON text."""
    ingest_url = f"{mcp_server.API_BASE_URL}/datasets/ds1/ingest"
    expected_body = dict(
        shape=[2, 2],
        dtype="float32",
        data=[[1, 2], [3, 4]],
        metadata={"a": 1},
    )
    server_reply = {"ok": True}
    make_mock_client(monkeypatch, "post", ingest_url, expected_body, server_reply)

    # The shape is passed as a tuple; the mock expects it serialized as a list.
    tool_result = await mcp_server.save_tensor.fn(
        "ds1", (2, 2), "float32", [[1, 2], [3, 4]], {"a": 1}
    )

    assert json.loads(tool_result.text) == server_reply


@pytest.mark.asyncio
async def test_get_tensor(monkeypatch):
    """get_tensor GETs the record URL and returns the reply as JSON text."""
    record = {"record_id": "abc"}
    make_mock_client(
        monkeypatch,
        "get",
        f"{mcp_server.API_BASE_URL}/datasets/ds1/tensors/abc",
        None,
        record,
    )

    outcome = await mcp_server.get_tensor.fn("ds1", "abc")

    assert json.loads(outcome.text) == record


@pytest.mark.asyncio
async def test_execute_nql_query(monkeypatch):
    """execute_nql_query POSTs {"query": ...} to /query and returns the reply as JSON text."""
    query_text = "count"
    reply = {"results": []}
    make_mock_client(
        monkeypatch,
        "post",
        f"{mcp_server.API_BASE_URL}/query",
        {"query": query_text},
        reply,
    )

    outcome = await mcp_server.execute_nql_query.fn(query_text)

    assert json.loads(outcome.text) == reply


@pytest.mark.asyncio
async def test_dataset_tools(monkeypatch):
    """Create, list, then delete a dataset, checking each tool's request and reply.

    Each step installs a fresh mock client that validates the verb, URL and
    body. The tool is then invoked, and the test asserts it returned the mocked
    reply as JSON text.
    """
    base = mcp_server.API_BASE_URL
    # (verb, url, expected body, mocked reply, deferred tool call)
    steps = [
        (
            "post",
            f"{base}/datasets/create",
            {"name": "ds1"},
            {"message": "ok"},
            lambda: mcp_server.tensorus_create_dataset.fn("ds1"),
        ),
        (
            "get",
            f"{base}/datasets",
            None,
            {"data": ["ds1"]},
            lambda: mcp_server.tensorus_list_datasets.fn(),
        ),
        (
            "delete",
            f"{base}/datasets/ds1",
            None,
            {"deleted": True},
            lambda: mcp_server.tensorus_delete_dataset.fn("ds1"),
        ),
    ]

    for verb, endpoint, expected_body, reply, invoke in steps:
        make_mock_client(monkeypatch, verb, endpoint, expected_body, reply)
        tool_result = await invoke()
        assert json.loads(tool_result.text) == reply


@pytest.mark.asyncio
async def test_tensor_tools(monkeypatch):
    """Ingest, fetch, delete, then update metadata of a tensor record.

    Each step installs a fresh mock client that validates the verb, URL and
    body. The tool is then invoked, and the test asserts it returned the mocked
    reply as JSON text.
    """
    record_url = f"{mcp_server.API_BASE_URL}/datasets/ds1/tensors/r1"
    # (verb, url, expected body, mocked reply, deferred tool call)
    steps = [
        (
            "post",
            f"{mcp_server.API_BASE_URL}/datasets/ds1/ingest",
            {"shape": [1], "dtype": "int32", "data": [1], "metadata": None},
            {"record_id": "r1"},
            lambda: mcp_server.tensorus_ingest_tensor.fn("ds1", [1], "int32", [1]),
        ),
        (
            "get",
            record_url,
            None,
            {"record_id": "r1", "data": [1]},
            lambda: mcp_server.tensorus_get_tensor_details.fn("ds1", "r1"),
        ),
        (
            "delete",
            record_url,
            None,
            {"deleted": True},
            lambda: mcp_server.tensorus_delete_tensor.fn("ds1", "r1"),
        ),
        (
            "put",
            f"{record_url}/metadata",
            {"new_metadata": {"x": 1}},
            {"updated": True},
            lambda: mcp_server.tensorus_update_tensor_metadata.fn(
                "ds1", "r1", {"x": 1}
            ),
        ),
    ]

    for verb, endpoint, expected_body, reply, invoke in steps:
        make_mock_client(monkeypatch, verb, endpoint, expected_body, reply)
        tool_result = await invoke()
        assert json.loads(tool_result.text) == reply


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("func", "operation", "payload"),
    [
        (mcp_server.tensorus_apply_unary_operation, "log", {"a": 1}),
        (mcp_server.tensorus_apply_binary_operation, "add", {"b": 2}),
        (mcp_server.tensorus_apply_list_operation, "concatenate", {"c": 3}),
    ],
)
async def test_tensor_ops(monkeypatch, func, operation, payload):
    """Each generic op tool POSTs its payload to /ops/<operation> and returns the reply."""
    op_url = f"{mcp_server.API_BASE_URL}/ops/{operation}"
    reply = {"result": 0}
    make_mock_client(monkeypatch, "post", op_url, payload, reply)

    outcome = await func.fn(operation, payload)

    assert json.loads(outcome.text) == reply


@pytest.mark.asyncio
async def test_tensor_ops_einsum(monkeypatch):
    """tensorus_apply_einsum POSTs its payload to /ops/einsum and returns the reply."""
    einsum_request = {"equation": "i,i->", "operands": [1, 2]}
    reply = {"result": 1}
    make_mock_client(
        monkeypatch,
        "post",
        f"{mcp_server.API_BASE_URL}/ops/einsum",
        einsum_request,
        reply,
    )

    outcome = await mcp_server.tensorus_apply_einsum.fn(einsum_request)

    assert json.loads(outcome.text) == reply


@pytest.mark.asyncio
async def test_http_error_returns_textcontent(monkeypatch):
    """An httpx.HTTPError during save_tensor is reported as {"error": <message>} text."""
    make_error_client(monkeypatch, "post")

    outcome = await mcp_server.save_tensor.fn("ds1", [1], "int32", [1])

    assert json.loads(outcome.text) == {"error": "failed"}