| import pytest |
| from unittest.mock import MagicMock, patch, call |
| pytest.importorskip("psycopg2") |
| import psycopg2 |
| from uuid import uuid4, UUID |
| from datetime import datetime |
| import json |
|
|
| from tensorus.metadata.postgres_storage import PostgresMetadataStorage |
| from tensorus.metadata.schemas import ( |
| TensorDescriptor, SemanticMetadata, DataType, StorageFormat, AccessControl, CompressionInfo, |
| LineageMetadata, ComputationalMetadata, QualityMetadata, RelationalMetadata, UsageMetadata, |
| LineageSource, LineageSourceType, ParentTensorLink |
| ) |
| from tensorus.metadata.schemas_iodata import TensorusExportData, TensorusExportEntry |
|
|
| |
|
|
@pytest.fixture
def mock_pool():
    """Build a mocked psycopg2 connection pool plus the cursor it hands out.

    Returns:
        tuple: ``(pool, mock_cursor)``. ``pool`` is a ``MagicMock`` specced
        against ``SimpleConnectionPool`` whose ``getconn()`` returns a mocked
        connection; ``mock_cursor`` is the object produced by entering
        ``conn.cursor()`` as a context manager on that connection.
    """
    # Import the submodule explicitly: a bare ``import psycopg2`` does not
    # guarantee that ``psycopg2.pool`` is available as an attribute, so the
    # fixture should not rely on some other module having imported it first.
    import psycopg2.pool

    mock_cursor = MagicMock()
    mock_conn = MagicMock()
    # The cursor is reached through the context-manager protocol
    # (``with conn.cursor() as cur:``), hence the ``__enter__`` hop.
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

    pool = MagicMock(spec=psycopg2.pool.SimpleConnectionPool)
    pool.getconn.return_value = mock_conn
    pool.putconn.return_value = None
    return pool, mock_cursor
|
|
@pytest.fixture
def pg_storage(mock_pool):
    """Return a PostgresMetadataStorage built while the pool class is patched.

    ``psycopg2.pool.SimpleConnectionPool`` is replaced only for the duration
    of the constructor call; any call to it during construction returns the
    mocked pool supplied by the ``mock_pool`` fixture.
    """
    fake_pool, _cursor = mock_pool
    dsn = "postgresql://mockuser:mockpass@mockhost/mockdb"
    with patch('psycopg2.pool.SimpleConnectionPool', return_value=fake_pool):
        return PostgresMetadataStorage(dsn=dsn)
|
|
|
|
| |
|
|
def test_pg_add_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_tensor_descriptor runs one INSERT with the expected parameter tuple."""
    cursor = mock_pool[1]
    tensor_id = uuid4()
    timestamp = datetime.utcnow()
    access = AccessControl(read=["user1"], write=["owner"])
    compression = CompressionInfo(algorithm="zstd", level=3)

    descriptor = TensorDescriptor(
        tensor_id=tensor_id,
        dimensionality=2,
        shape=[10, 20],
        data_type=DataType.FLOAT32,
        storage_format=StorageFormat.RAW,
        creation_timestamp=timestamp,
        last_modified_timestamp=timestamp,
        owner="test_owner",
        access_control=access,
        byte_size=1600,
        checksum="chk123",
        compression_info=compression,
        tags=["tag1", "tag2"],
        metadata={"key": "value"},
    )
    pg_storage.add_tensor_descriptor(descriptor)

    # Exactly one statement must have been sent to the cursor.
    assert cursor.execute.call_count == 1
    positional, _kwargs = cursor.execute.call_args
    assert "INSERT INTO tensor_descriptors" in positional[0]

    # Enum fields are compared against plain strings, the nested models
    # against their JSON dumps, and the free-form metadata against
    # json.dumps output.
    expected_params = (
        tensor_id,
        2,
        [10, 20],
        'float32',
        'raw',
        timestamp,
        timestamp,
        'test_owner',
        access.model_dump_json(),
        1600,
        'chk123',
        compression.model_dump_json(),
        ['tag1', 'tag2'],
        json.dumps({"key": "value"}),
    )
    assert positional[1] == expected_params
|
|
|
|
def test_pg_get_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_tensor_descriptor selects by id and maps the row onto a descriptor."""
    cursor = mock_pool[1]
    tensor_id = uuid4()
    timestamp = datetime.utcnow()

    # Row returned by the mocked cursor for the lookup.
    cursor.fetchone.return_value = {
        'tensor_id': tensor_id,
        'dimensionality': 1,
        'shape': [5],
        'data_type': 'int32',
        'storage_format': 'numpy_npz',
        'creation_timestamp': timestamp,
        'last_modified_timestamp': timestamp,
        'owner': 'fetch_user',
        'access_control': {"read": ["public"]},
        'byte_size': 20,
        'checksum': None,
        'compression_info': None,
        'tags': ['fetched'],
        'metadata': {'source': 'db'},
    }

    fetched = pg_storage.get_tensor_descriptor(tensor_id)

    # Query side: a single parameterised SELECT.
    assert cursor.execute.call_count == 1
    positional, _kwargs = cursor.execute.call_args
    assert positional[0] == "SELECT * FROM tensor_descriptors WHERE tensor_id = %s;"
    assert positional[1] == (tensor_id,)

    # Result side: the row values show up on the returned descriptor.
    assert fetched is not None
    assert fetched.tensor_id == tensor_id
    assert fetched.owner == 'fetch_user'
    assert fetched.data_type == DataType.INT32
    assert fetched.access_control.read == ["public"]
    assert fetched.tags == ['fetched']
|
|
def test_pg_delete_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """delete_tensor_descriptor returns True and issues one DELETE by id."""
    cursor = mock_pool[1]
    tensor_id = uuid4()
    # The mocked cursor reports one affected row.
    cursor.rowcount = 1

    deleted = pg_storage.delete_tensor_descriptor(tensor_id)

    assert deleted is True
    cursor.execute.assert_called_once_with(
        "DELETE FROM tensor_descriptors WHERE tensor_id = %s;",
        (tensor_id,),
    )
|
|
|
|
| |
def test_pg_add_semantic_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_semantic_metadata issues two statements, the second being the INSERT."""
    cursor = mock_pool[1]
    tensor_id = uuid4()

    timestamp = datetime.utcnow()
    # Descriptor row handed back by fetchone; presumably consumed by a
    # parent-descriptor lookup before the insert -- confirm against the
    # storage implementation.
    descriptor_row = {
        'tensor_id': tensor_id,
        'dimensionality': 1,
        'shape': [1],
        'data_type': 'float32',
        'storage_format': 'raw',
        'creation_timestamp': timestamp,
        'last_modified_timestamp': timestamp,
        'owner': 'owner',
        'access_control': {},
        'byte_size': 4,
        'checksum': None,
        'compression_info': None,
        'tags': [],
        'metadata': {},
    }
    cursor.fetchone.return_value = descriptor_row

    semantic = SemanticMetadata(tensor_id=tensor_id, name="purpose", description="for science")
    pg_storage.add_semantic_metadata(semantic)

    assert cursor.execute.call_count == 2
    # Only the second recorded call is inspected.
    second_call = cursor.execute.call_args_list[1]
    insert_positional, _kwargs = second_call

    assert "INSERT INTO semantic_metadata_entries" in insert_positional[0]
    assert insert_positional[1] == (tensor_id, "purpose", "for science")
|
|
|
|
| |
def test_pg_add_jsonb_lineage_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_lineage_metadata inserts the model's JSON dump under named params."""
    cursor = mock_pool[1]
    tensor_id = uuid4()

    timestamp = datetime.utcnow()
    # Descriptor row handed back by fetchone; presumably consumed by a
    # parent-descriptor lookup before the insert -- confirm against the
    # storage implementation.
    descriptor_row = {
        'tensor_id': tensor_id,
        'dimensionality': 1,
        'shape': [1],
        'data_type': 'float32',
        'storage_format': 'raw',
        'creation_timestamp': timestamp,
        'last_modified_timestamp': timestamp,
        'owner': 'owner',
        'access_control': {},
        'byte_size': 4,
        'checksum': None,
        'compression_info': None,
        'tags': [],
        'metadata': {},
    }
    cursor.fetchone.return_value = descriptor_row

    lineage = LineageMetadata(tensor_id=tensor_id, version="v1.pg")
    pg_storage.add_lineage_metadata(lineage)

    assert cursor.execute.call_count == 2
    # Only the second recorded call is inspected.
    second_call = cursor.execute.call_args_list[1]
    insert_positional, _kwargs = second_call

    assert "INSERT INTO lineage_metadata (tensor_id, data)" in insert_positional[0]

    expected_params = {"tensor_id": tensor_id, "data": lineage.model_dump_json()}
    assert insert_positional[1] == expected_params
|
|
|
|
| |
def test_pg_list_tensor_descriptors_with_filters(pg_storage: PostgresMetadataStorage, mock_pool):
    """Owner and lineage-version filters appear in the SQL with named params."""
    cursor = mock_pool[1]

    pg_storage.list_tensor_descriptors(owner="filter_owner", lineage_version="vFilter")

    assert cursor.execute.call_count == 1
    positional, _kwargs = cursor.execute.call_args
    sql_text = positional[0]
    bound = positional[1]

    # Fragments the generated statement must contain, checked in order.
    required_fragments = (
        "FROM tensor_descriptors td",
        "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id",
        "td.owner = %(owner)s",
        "lm.data->>'version' = %(lineage_version)s",
    )
    for fragment in required_fragments:
        assert fragment in sql_text
    assert bound == {"owner": "filter_owner", "lineage_version": "vFilter"}
|
|
|
|
| |
def test_pg_get_parent_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_parent_tensor_ids expands the JSONB parent array for one tensor id."""
    cursor = mock_pool[1]
    tensor_id = uuid4()
    # Two rows whose parent ids are strings.
    cursor.fetchall.return_value = [{'parent_id': str(uuid4())} for _ in range(2)]

    pg_storage.get_parent_tensor_ids(tensor_id)

    assert cursor.execute.call_count == 1
    positional, _kwargs = cursor.execute.call_args
    sql_text, bound = positional[0], positional[1]

    assert "jsonb_array_elements(lm.data->'parent_tensors')" in sql_text
    assert "WHERE lm.tensor_id = %(tensor_id)s" in sql_text
    assert bound == {"tensor_id": tensor_id}
|
|
def test_pg_get_child_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_child_tensor_ids matches the stringified id against parent entries."""
    cursor = mock_pool[1]
    tensor_id = uuid4()
    # Two rows whose tensor ids are UUID objects.
    cursor.fetchall.return_value = [{'tensor_id': uuid4()} for _ in range(2)]

    pg_storage.get_child_tensor_ids(tensor_id)

    assert cursor.execute.call_count == 1
    positional, _kwargs = cursor.execute.call_args
    sql_text, bound = positional[0], positional[1]

    assert "jsonb_array_elements(lm.data->'parent_tensors') AS parent" in sql_text
    assert "parent.value->>'tensor_id' = %(target_parent_id)s" in sql_text
    # The bound parameter is the string form of the UUID.
    assert bound == {"target_parent_id": str(tensor_id)}
|
|
|
|
| |
def test_pg_search_tensor_descriptors_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Searching three fields yields ILIKE clauses, both joins and one param."""
    cursor = mock_pool[1]

    search_fields = ["owner", "semantic.description", "lineage.version"]
    pg_storage.search_tensor_descriptors("test_query", search_fields)

    assert cursor.execute.call_count == 1
    positional, _kwargs = cursor.execute.call_args
    sql_text = positional[0]
    bound = positional[1]

    # Fragments the generated statement must contain, checked in order.
    required_fragments = (
        "owner ILIKE %(text_query)s",
        "description ILIKE %(text_query)s",
        "lm.data->>'version' ILIKE %(text_query)s",
        "LEFT JOIN semantic_metadata_entries sm ON td.tensor_id = sm.tensor_id",
        "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id",
    )
    for fragment in required_fragments:
        assert fragment in sql_text
    # The query text is bound wrapped in SQL wildcards.
    assert bound == {"text_query": "%test_query%"}
|
|
|
|
| |
def test_pg_aggregate_tensor_descriptors_count_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """aggregate_tensor_descriptors("data_type", "count") raises ValueError."""
    cursor = mock_pool[1]
    aggregate_rows = [{'data_type': 'float32', 'count': 5}]
    cursor.fetchall.return_value = aggregate_rows

    with pytest.raises(ValueError):
        pg_storage.aggregate_tensor_descriptors("data_type", "count")
|
|
| |
| |
| |
|
|
def test_pg_check_health_ok(pg_storage: PostgresMetadataStorage, mock_pool):
    """check_health reports (True, "postgres") after a single ``SELECT 1;``."""
    cursor = mock_pool[1]
    cursor.execute.return_value = None

    healthy, backend_name = pg_storage.check_health()

    assert healthy is True
    assert backend_name == "postgres"
    cursor.execute.assert_called_once_with("SELECT 1;", None)
|
|
def test_pg_check_health_fail(pg_storage: PostgresMetadataStorage, mock_pool):
    """check_health reports (False, "postgres") when the query raises.

    The mocked cursor's ``execute`` is configured to raise ``psycopg2.Error``;
    the call is expected to return an unhealthy status rather than propagate
    the exception.
    """
    # Only the cursor is needed here; the pool half of the fixture is unused.
    _, mock_cursor = mock_pool
    mock_cursor.execute.side_effect = psycopg2.Error("Connection failed")

    is_healthy, backend = pg_storage.check_health()

    assert is_healthy is False
    assert backend == "postgres"
|
|
def test_pg_get_tensor_descriptors_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_tensor_descriptors_count returns the ``count`` column of one query."""
    cursor = mock_pool[1]
    cursor.fetchone.return_value = {'count': 123}

    total = pg_storage.get_tensor_descriptors_count()

    assert total == 123
    cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM tensor_descriptors;",
        None,
    )
|
|
def test_pg_get_extended_metadata_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """Known metadata types are counted per table; an unknown type yields 0.

    The closing call-count assertion shows that the unknown type did not add
    a third ``execute`` call.
    """
    cursor = mock_pool[1]
    cursor.fetchone.return_value = {'count': 42}

    # First known type: exactly one query so far.
    lineage_total = pg_storage.get_extended_metadata_count("LineageMetadata")
    assert lineage_total == 42
    cursor.execute.assert_called_once_with("SELECT COUNT(*) as count FROM lineage_metadata;", None)

    # Second known type: only the most recent call is checked.
    semantic_total = pg_storage.get_extended_metadata_count("SemanticMetadata")
    assert semantic_total == 42
    cursor.execute.assert_called_with("SELECT COUNT(*) as count FROM semantic_metadata_entries;", None)

    # Unknown type: result is zero and no extra query was executed.
    unknown_total = pg_storage.get_extended_metadata_count("UnknownMeta")
    assert unknown_total == 0

    assert cursor.execute.call_count == 2
|
|