File size: 13,622 Bytes
edfa748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import json
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch, call
from uuid import uuid4, UUID

import pytest

pytest.importorskip("psycopg2")
import psycopg2

from tensorus.metadata.postgres_storage import PostgresMetadataStorage
from tensorus.metadata.schemas import (
    TensorDescriptor, SemanticMetadata, DataType, StorageFormat, AccessControl, CompressionInfo,
    LineageMetadata, ComputationalMetadata, QualityMetadata, RelationalMetadata, UsageMetadata,
    LineageSource, LineageSourceType, ParentTensorLink # For constructing objects
)
from tensorus.metadata.schemas_iodata import TensorusExportData, TensorusExportEntry # For type hints

# --- Mocks ---

@pytest.fixture
def mock_pool():
    """Provide a mocked psycopg2 connection pool plus the cursor it hands out.

    Returns a ``(pool, cursor)`` pair so tests can assert on the SQL that the
    storage layer executes against the cursor.
    """
    cursor = MagicMock()
    conn = MagicMock()
    # Support the 'with conn.cursor() as cur:' pattern used by the storage layer.
    conn.cursor.return_value.__enter__.return_value = cursor

    fake_pool = MagicMock(spec=psycopg2.pool.SimpleConnectionPool)
    fake_pool.getconn.return_value = conn
    fake_pool.putconn.return_value = None  # Returning a connection is a no-op.
    return fake_pool, cursor

@pytest.fixture
def pg_storage(mock_pool):
    """Build a PostgresMetadataStorage wired to the mocked connection pool."""
    fake_pool, _ = mock_pool
    # Intercept pool creation so the storage never opens a real connection.
    with patch('psycopg2.pool.SimpleConnectionPool', return_value=fake_pool):
        return PostgresMetadataStorage(dsn="postgresql://mockuser:mockpass@mockhost/mockdb")


# --- TensorDescriptor Method Tests (Conceptual SQL Generation) ---

def test_pg_add_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_tensor_descriptor issues one INSERT carrying every descriptor field.

    Verifies both the statement text and the ordered parameter tuple,
    including JSON-serialized nested models (access_control, compression_info).
    """
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # timestamp; use an explicitly UTC-aware timestamp instead.
    now = datetime.now(timezone.utc)
    ac = AccessControl(read=["user1"], write=["owner"])
    ci = CompressionInfo(algorithm="zstd", level=3)

    td = TensorDescriptor(
        tensor_id=td_id, dimensionality=2, shape=[10, 20], data_type=DataType.FLOAT32,
        storage_format=StorageFormat.RAW, creation_timestamp=now, last_modified_timestamp=now,
        owner="test_owner", access_control=ac, byte_size=1600, checksum="chk123",
        compression_info=ci, tags=["tag1", "tag2"], metadata={"key": "value"}
    )
    pg_storage.add_tensor_descriptor(td)

    # Exactly one statement should be executed; inspect its text and params.
    assert mock_cursor.execute.call_count == 1
    args, _ = mock_cursor.execute.call_args
    assert "INSERT INTO tensor_descriptors" in args[0]

    # The storage layer binds a positional tuple; nested Pydantic models are
    # serialized with model_dump_json() (Pydantic v2).
    expected_params = (
        td_id, 2, [10, 20], 'float32', 'raw', now, now, 'test_owner',
        ac.model_dump_json(), 1600, 'chk123', ci.model_dump_json(),
        ['tag1', 'tag2'], json.dumps({"key": "value"})
    )
    assert args[1] == expected_params


def test_pg_get_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_tensor_descriptor selects by id and rehydrates a TensorDescriptor."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)

    # Simulate the dict-style row returned from the database.
    mock_db_row = {
        'tensor_id': td_id, 'dimensionality': 1, 'shape': [5], 'data_type': 'int32',
        'storage_format': 'numpy_npz', 'creation_timestamp': now, 'last_modified_timestamp': now,
        'owner': 'fetch_user', 'access_control': {"read": ["public"]}, 'byte_size': 20,
        'checksum': None, 'compression_info': None, 'tags': ['fetched'], 'metadata': {'source': 'db'}
    }
    mock_cursor.fetchone.return_value = mock_db_row

    descriptor = pg_storage.get_tensor_descriptor(td_id)

    # The lookup is a single parameterized SELECT keyed by tensor_id.
    assert mock_cursor.execute.call_count == 1
    query_args, _ = mock_cursor.execute.call_args
    assert query_args[0] == "SELECT * FROM tensor_descriptors WHERE tensor_id = %s;"
    assert query_args[1] == (td_id,)

    # Row values must be converted back into typed model fields.
    assert descriptor is not None
    assert descriptor.tensor_id == td_id
    assert descriptor.owner == 'fetch_user'
    assert descriptor.data_type == DataType.INT32
    assert descriptor.access_control.read == ["public"]
    assert descriptor.tags == ['fetched']

def test_pg_delete_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """Deleting an existing descriptor returns True and runs a single DELETE."""
    _, cursor = mock_pool
    target_id = uuid4()
    cursor.rowcount = 1  # Pretend exactly one row was removed.

    assert pg_storage.delete_tensor_descriptor(target_id) is True
    cursor.execute.assert_called_once_with(
        "DELETE FROM tensor_descriptors WHERE tensor_id = %s;", (target_id,)
    )


# --- SemanticMetadata Method Tests ---
def test_pg_add_semantic_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_semantic_metadata checks the parent TD exists, then INSERTs the entry."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    # Minimal-but-complete TD row so the parent-existence lookup succeeds.
    mock_cursor.fetchone.return_value = {
        'tensor_id': td_id,
        'dimensionality': 1,
        'shape': [1],
        'data_type': 'float32',
        'storage_format': 'raw',
        'creation_timestamp': now,
        'last_modified_timestamp': now,
        'owner': 'owner',
        'access_control': {},
        'byte_size': 4,
        'checksum': None,
        'compression_info': None,
        'tags': [],
        'metadata': {}
    }

    sm = SemanticMetadata(tensor_id=td_id, name="purpose", description="for science")
    pg_storage.add_semantic_metadata(sm)

    # Call 1: get_tensor_descriptor (parent check); call 2: the semantic INSERT.
    assert mock_cursor.execute.call_count == 2
    insert_call_args, _ = mock_cursor.execute.call_args_list[1]  # Second call

    assert "INSERT INTO semantic_metadata_entries" in insert_call_args[0]
    assert insert_call_args[1] == (td_id, "purpose", "for science")


# --- JSONB Extended Metadata Tests (Example with LineageMetadata) ---
def test_pg_add_jsonb_lineage_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """Lineage metadata is stored as a JSONB payload keyed by tensor_id."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    # Parent TD row so the _add_jsonb_metadata existence check passes.
    mock_cursor.fetchone.return_value = {
        'tensor_id': td_id,
        'dimensionality': 1,
        'shape': [1],
        'data_type': 'float32',
        'storage_format': 'raw',
        'creation_timestamp': now,
        'last_modified_timestamp': now,
        'owner': 'owner',
        'access_control': {},
        'byte_size': 4,
        'checksum': None,
        'compression_info': None,
        'tags': [],
        'metadata': {}
    }

    lm = LineageMetadata(tensor_id=td_id, version="v1.pg")
    pg_storage.add_lineage_metadata(lm)

    # Call 1: get_tensor_descriptor (parent check); call 2: the lineage INSERT.
    assert mock_cursor.execute.call_count == 2
    insert_call_args, _ = mock_cursor.execute.call_args_list[1]

    assert "INSERT INTO lineage_metadata (tensor_id, data)" in insert_call_args[0]
    # Payload is the full model serialized via Pydantic v2 model_dump_json().
    assert insert_call_args[1] == {"tensor_id": td_id, "data": lm.model_dump_json()}


# --- Test list_tensor_descriptors with filters (SQL construction) ---
def test_pg_list_tensor_descriptors_with_filters(pg_storage: PostgresMetadataStorage, mock_pool):
    """Owner + lineage_version filters produce a lineage join with named params."""
    _, cursor = mock_pool

    pg_storage.list_tensor_descriptors(owner="filter_owner", lineage_version="vFilter")

    assert cursor.execute.call_count == 1
    (sql_query, params), _ = cursor.execute.call_args

    # Base table, join for lineage filtering, and both WHERE fragments.
    assert "FROM tensor_descriptors td" in sql_query
    assert "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id" in sql_query
    assert "td.owner = %(owner)s" in sql_query
    assert "lm.data->>'version' = %(lineage_version)s" in sql_query
    assert params == {"owner": "filter_owner", "lineage_version": "vFilter"}


# --- Test Lineage Parent/Child SQL Construction ---
def test_pg_get_parent_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Parent lookup unnests the lineage JSONB parent_tensors array."""
    _, cursor = mock_pool
    child_id = uuid4()
    # Simulate two parent rows coming back from the database.
    cursor.fetchall.return_value = [
        {'parent_id': str(uuid4())},
        {'parent_id': str(uuid4())},
    ]

    pg_storage.get_parent_tensor_ids(child_id)

    assert cursor.execute.call_count == 1
    (sql_query, params), _ = cursor.execute.call_args

    assert "jsonb_array_elements(lm.data->'parent_tensors')" in sql_query
    assert "WHERE lm.tensor_id = %(tensor_id)s" in sql_query
    assert params == {"tensor_id": child_id}

def test_pg_get_child_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Child lookup matches our id inside other tensors' parent_tensors arrays."""
    _, cursor = mock_pool
    parent_id = uuid4()
    cursor.fetchall.return_value = [{'tensor_id': uuid4()}, {'tensor_id': uuid4()}]

    pg_storage.get_child_tensor_ids(parent_id)

    assert cursor.execute.call_count == 1
    (sql_query, params), _ = cursor.execute.call_args

    assert "jsonb_array_elements(lm.data->'parent_tensors') AS parent" in sql_query
    assert "parent.value->>'tensor_id' = %(target_parent_id)s" in sql_query
    # The JSONB comparison is textual, so the UUID is bound as a string.
    assert params == {"target_parent_id": str(parent_id)}


# --- Test Search SQL Construction (Simplified) ---
def test_pg_search_tensor_descriptors_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Text search ILIKEs each requested field and joins the needed tables."""
    _, cursor = mock_pool

    pg_storage.search_tensor_descriptors(
        "test_query", ["owner", "semantic.description", "lineage.version"]
    )

    assert cursor.execute.call_count == 1
    (sql_query, params), _ = cursor.execute.call_args

    assert "owner ILIKE %(text_query)s" in sql_query
    # semantic.* fields resolve to the semantic_metadata_entries alias (sm).
    assert "description ILIKE %(text_query)s" in sql_query
    # lineage.* fields resolve to JSONB lookups on the lineage alias (lm).
    assert "lm.data->>'version' ILIKE %(text_query)s" in sql_query
    assert "LEFT JOIN semantic_metadata_entries sm ON td.tensor_id = sm.tensor_id" in sql_query
    assert "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id" in sql_query
    # The query text is wrapped with % wildcards for substring matching.
    assert params == {"text_query": "%test_query%"}


# --- Test Aggregation SQL Construction (Simplified for count) ---
def test_pg_aggregate_tensor_descriptors_count_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """aggregate_tensor_descriptors currently rejects 'count' with ValueError."""
    _, cursor = mock_pool
    # Stubbed result is never consumed because the call raises first.
    cursor.fetchall.return_value = [{'data_type': 'float32', 'count': 5}]

    with pytest.raises(ValueError):
        pg_storage.aggregate_tensor_descriptors("data_type", "count")

# Note: Full implementation of export/import for Postgres is complex and marked NotImplemented.
# Tests for those would require significant mocking or a live DB and are out of scope here.
# Similarly, health check and count methods are simple SQL and tested here conceptually.

def test_pg_check_health_ok(pg_storage: PostgresMetadataStorage, mock_pool):
    """A successful probe query reports (True, "postgres")."""
    _, cursor = mock_pool
    cursor.execute.return_value = None  # Probe query succeeds.

    healthy, backend_name = pg_storage.check_health()

    assert healthy is True
    assert backend_name == "postgres"
    cursor.execute.assert_called_once_with("SELECT 1;", None)

def test_pg_check_health_fail(pg_storage: PostgresMetadataStorage, mock_pool):
    """A database error during the probe reports (False, "postgres")."""
    connection_pool, cursor = mock_pool
    cursor.execute.side_effect = psycopg2.Error("Connection failed")
    # NOTE: if the failure happened in getconn instead, that would need
    # connection_pool.getconn.side_effect = psycopg2.Error("Pool failed")

    healthy, backend_name = pg_storage.check_health()

    assert healthy is False
    assert backend_name == "postgres"

def test_pg_get_tensor_descriptors_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """The descriptor count comes from a single COUNT(*) over the TD table."""
    _, cursor = mock_pool
    cursor.fetchone.return_value = {'count': 123}

    assert pg_storage.get_tensor_descriptors_count() == 123
    cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM tensor_descriptors;", None
    )

def test_pg_get_extended_metadata_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """Known metadata kinds map to their tables; unknown kinds return 0 with no SQL."""
    _, cursor = mock_pool
    cursor.fetchone.return_value = {'count': 42}

    assert pg_storage.get_extended_metadata_count("LineageMetadata") == 42
    cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM lineage_metadata;", None
    )

    # SemanticMetadata is backed by semantic_metadata_entries; the same
    # stubbed fetchone serves this second query too.
    assert pg_storage.get_extended_metadata_count("SemanticMetadata") == 42
    cursor.execute.assert_called_with(
        "SELECT COUNT(*) as count FROM semantic_metadata_entries;", None
    )

    # An unknown metadata name has no table mapping: the storage layer logs a
    # warning and returns 0 without touching the database.
    assert pg_storage.get_extended_metadata_count("UnknownMeta") == 0
    # Still only the two executes from the known-kind lookups above.
    assert cursor.execute.call_count == 2