"""Unit tests for PostgresMetadataStorage against a mocked psycopg2 connection pool.

No real database is used: the pool, connection and cursor are MagicMocks, so these
tests check the SQL text and parameters handed to ``cursor.execute`` and how mocked
rows are mapped back to schema objects.
"""
import pytest
from unittest.mock import MagicMock, patch, call

pytest.importorskip("psycopg2")
import psycopg2
import psycopg2.pool  # Explicit: `import psycopg2` alone does not load the `pool` submodule.
from uuid import uuid4, UUID
from datetime import datetime, timezone
import json

from tensorus.metadata.postgres_storage import PostgresMetadataStorage
from tensorus.metadata.schemas import (
    TensorDescriptor, SemanticMetadata, DataType, StorageFormat, AccessControl,
    CompressionInfo, LineageMetadata, ComputationalMetadata, QualityMetadata,
    RelationalMetadata, UsageMetadata, LineageSource, LineageSourceType,
    ParentTensorLink,  # For constructing objects
)
from tensorus.metadata.schemas_iodata import TensorusExportData, TensorusExportEntry  # For type hints


# --- Helpers ---

def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()`` (deprecated since Python 3.12),
    without the DeprecationWarning.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _minimal_td_row(tensor_id, timestamp: datetime) -> dict:
    """Build a minimal ``tensor_descriptors`` row dict for ``cursor.fetchone`` mocks.

    Used where a method first looks up the parent tensor descriptor before inserting.
    """
    return {
        'tensor_id': tensor_id, 'dimensionality': 1, 'shape': [1], 'data_type': 'float32',
        'storage_format': 'raw', 'creation_timestamp': timestamp,
        'last_modified_timestamp': timestamp, 'owner': 'owner', 'access_control': {},
        'byte_size': 4, 'checksum': None, 'compression_info': None, 'tags': [], 'metadata': {},
    }


# --- Mocks ---

@pytest.fixture
def mock_pool():
    """Mock the psycopg2 connection pool.

    Returns a ``(pool, cursor)`` tuple. ``pool.getconn()`` yields a mock connection
    whose ``cursor()`` context manager yields the returned cursor, so tests can
    inspect ``cursor.execute`` calls and stub ``fetchone``/``fetchall``/``rowcount``.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor  # For 'with ... as cur:'
    pool = MagicMock(spec=psycopg2.pool.SimpleConnectionPool)
    pool.getconn.return_value = mock_conn
    pool.putconn.return_value = None  # Doesn't need to do anything
    return pool, mock_cursor  # Return cursor for assertions


@pytest.fixture
def pg_storage(mock_pool):
    """Provide a PostgresMetadataStorage instance wired to the mocked pool."""
    pool, _ = mock_pool
    # Replace psycopg2.pool.SimpleConnectionPool while the storage is constructed.
    # NOTE(review): this only intercepts construction if postgres_storage looks the class
    # up through the psycopg2.pool module at call time (not via
    # `from psycopg2.pool import SimpleConnectionPool` at import time) -- confirm.
    with patch('psycopg2.pool.SimpleConnectionPool', return_value=pool):
        storage = PostgresMetadataStorage(dsn="postgresql://mockuser:mockpass@mockhost/mockdb")
    return storage


# --- TensorDescriptor Method Tests (Conceptual SQL Generation) ---

def test_pg_add_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_tensor_descriptor issues one INSERT with positional (tuple) parameters."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    now = _utcnow()
    ac = AccessControl(read=["user1"], write=["owner"])
    ci = CompressionInfo(algorithm="zstd", level=3)
    td = TensorDescriptor(
        tensor_id=td_id, dimensionality=2, shape=[10, 20], data_type=DataType.FLOAT32,
        storage_format=StorageFormat.RAW, creation_timestamp=now,
        last_modified_timestamp=now, owner="test_owner", access_control=ac,
        byte_size=1600, checksum="chk123", compression_info=ci,
        tags=["tag1", "tag2"], metadata={"key": "value"},
    )

    pg_storage.add_tensor_descriptor(td)

    expected_query_part = "INSERT INTO tensor_descriptors"  # Check a part of the query
    assert mock_cursor.execute.call_count == 1
    args, _ = mock_cursor.execute.call_args
    assert expected_query_part in args[0]
    # add_tensor_descriptor passes tuple params, so order matters here.
    # For Pydantic v2, model_dump_json() is used. For v1, .json()
    expected_params = (
        td_id, 2, [10, 20], 'float32', 'raw', now, now, 'test_owner',
        ac.model_dump_json(), 1600, 'chk123', ci.model_dump_json(),
        ['tag1', 'tag2'], json.dumps({"key": "value"}),
    )
    assert args[1] == expected_params


def test_pg_get_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_tensor_descriptor selects by id and maps the row dict to a TensorDescriptor."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    now = _utcnow()
    # Simulate a row returned from the database
    mock_db_row = {
        'tensor_id': td_id, 'dimensionality': 1, 'shape': [5], 'data_type': 'int32',
        'storage_format': 'numpy_npz', 'creation_timestamp': now,
        'last_modified_timestamp': now, 'owner': 'fetch_user',
        'access_control': {"read": ["public"]}, 'byte_size': 20, 'checksum': None,
        'compression_info': None, 'tags': ['fetched'], 'metadata': {'source': 'db'},
    }
    mock_cursor.fetchone.return_value = mock_db_row

    descriptor = pg_storage.get_tensor_descriptor(td_id)

    assert mock_cursor.execute.call_count == 1
    query_args, _ = mock_cursor.execute.call_args
    assert query_args[0] == "SELECT * FROM tensor_descriptors WHERE tensor_id = %s;"
    assert query_args[1] == (td_id,)
    assert descriptor is not None
    assert descriptor.tensor_id == td_id
    assert descriptor.owner == 'fetch_user'
    assert descriptor.data_type == DataType.INT32
    assert descriptor.access_control.read == ["public"]
    assert descriptor.tags == ['fetched']


def test_pg_delete_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """delete_tensor_descriptor returns True when the DELETE affects one row."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    mock_cursor.rowcount = 1  # Simulate one row deleted

    result = pg_storage.delete_tensor_descriptor(td_id)

    assert result is True
    mock_cursor.execute.assert_called_once_with(
        "DELETE FROM tensor_descriptors WHERE tensor_id = %s;", (td_id,)
    )


# --- SemanticMetadata Method Tests ---

def test_pg_add_semantic_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_semantic_metadata looks up the parent descriptor, then INSERTs the entry."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # Simulate parent TD exists (get_tensor_descriptor is called by add_semantic_metadata)
    mock_cursor.fetchone.return_value = _minimal_td_row(td_id, _utcnow())
    sm = SemanticMetadata(tensor_id=td_id, name="purpose", description="for science")

    pg_storage.add_semantic_metadata(sm)

    # First call to get_tensor_descriptor, then to INSERT semantic
    assert mock_cursor.execute.call_count == 2
    insert_call_args, _ = mock_cursor.execute.call_args_list[1]  # Second call
    assert "INSERT INTO semantic_metadata_entries" in insert_call_args[0]
    assert insert_call_args[1] == (td_id, "purpose", "for science")


# --- JSONB Extended Metadata Tests (Example with LineageMetadata) ---

def test_pg_add_jsonb_lineage_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """add_lineage_metadata stores the model as JSON in lineage_metadata.data."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # Simulate parent TD exists for the _add_jsonb_metadata check
    mock_cursor.fetchone.return_value = _minimal_td_row(td_id, _utcnow())
    lm = LineageMetadata(tensor_id=td_id, version="v1.pg")

    pg_storage.add_lineage_metadata(lm)

    assert mock_cursor.execute.call_count == 2  # 1 for get_td, 1 for insert lineage
    insert_call_args, _ = mock_cursor.execute.call_args_list[1]
    assert "INSERT INTO lineage_metadata (tensor_id, data)" in insert_call_args[0]
    # Pydantic v2: lm.model_dump_json()
    assert insert_call_args[1] == {"tensor_id": td_id, "data": lm.model_dump_json()}


# --- Test list_tensor_descriptors with filters (SQL construction) ---

def test_pg_list_tensor_descriptors_with_filters(pg_storage: PostgresMetadataStorage, mock_pool):
    """owner and lineage_version filters become named-parameter WHERE clauses."""
    _, mock_cursor = mock_pool

    # Test with owner and lineage_version filters
    pg_storage.list_tensor_descriptors(owner="filter_owner", lineage_version="vFilter")

    assert mock_cursor.execute.call_count == 1
    query_args, _ = mock_cursor.execute.call_args
    sql_query, params = query_args[0], query_args[1]
    assert "FROM tensor_descriptors td" in sql_query
    assert "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id" in sql_query
    assert "td.owner = %(owner)s" in sql_query
    assert "lm.data->>'version' = %(lineage_version)s" in sql_query
    assert params == {"owner": "filter_owner", "lineage_version": "vFilter"}


# --- Test Lineage Parent/Child SQL Construction ---

def test_pg_get_parent_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_parent_tensor_ids expands lm.data->'parent_tensors' for the given tensor."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    # Simulate DB response
    mock_cursor.fetchall.return_value = [{'parent_id': str(uuid4())}, {'parent_id': str(uuid4())}]

    pg_storage.get_parent_tensor_ids(td_id)

    assert mock_cursor.execute.call_count == 1
    query_args, _ = mock_cursor.execute.call_args
    sql_query, params = query_args[0], query_args[1]
    assert "jsonb_array_elements(lm.data->'parent_tensors')" in sql_query
    assert "WHERE lm.tensor_id = %(tensor_id)s" in sql_query
    assert params == {"tensor_id": td_id}


def test_pg_get_child_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_child_tensor_ids matches parent_tensors entries against the id as a string."""
    _, mock_cursor = mock_pool
    td_id = uuid4()
    mock_cursor.fetchall.return_value = [{'tensor_id': uuid4()}, {'tensor_id': uuid4()}]

    pg_storage.get_child_tensor_ids(td_id)

    assert mock_cursor.execute.call_count == 1
    query_args, _ = mock_cursor.execute.call_args
    sql_query, params = query_args[0], query_args[1]
    assert "jsonb_array_elements(lm.data->'parent_tensors') AS parent" in sql_query
    assert "parent.value->>'tensor_id' = %(target_parent_id)s" in sql_query
    assert params == {"target_parent_id": str(td_id)}


# --- Test Search SQL Construction (Simplified) ---

def test_pg_search_tensor_descriptors_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Search fields map to ILIKE clauses with the needed LEFT JOINs."""
    _, mock_cursor = mock_pool

    pg_storage.search_tensor_descriptors(
        "test_query", ["owner", "semantic.description", "lineage.version"]
    )

    assert mock_cursor.execute.call_count == 1
    query_args, _ = mock_cursor.execute.call_args
    sql_query, params = query_args[0], query_args[1]
    assert "owner ILIKE %(text_query)s" in sql_query
    # Assuming semantic_metadata_entries table alias sm
    assert "description ILIKE %(text_query)s" in sql_query
    # Assuming lineage_metadata alias lm
    assert "lm.data->>'version' ILIKE %(text_query)s" in sql_query
    assert "LEFT JOIN semantic_metadata_entries sm ON td.tensor_id = sm.tensor_id" in sql_query
    assert "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id" in sql_query
    assert params == {"text_query": "%test_query%"}


# --- Test Aggregation SQL Construction (Simplified for count) ---

def test_pg_aggregate_tensor_descriptors_count_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """aggregate_tensor_descriptors('data_type', 'count') is expected to raise ValueError."""
    _, mock_cursor = mock_pool
    # NOTE(review): because a ValueError is expected, this fetchall stub is presumably
    # never consumed -- confirm whether 'count' aggregation is meant to be supported.
    mock_cursor.fetchall.return_value = [{'data_type': 'float32', 'count': 5}]

    with pytest.raises(ValueError):
        pg_storage.aggregate_tensor_descriptors("data_type", "count")


# Note: Full implementation of export/import for Postgres is complex and marked NotImplemented.
# Tests for those would require significant mocking or a live DB and are out of scope here.
# Similarly, health check and count methods are simple SQL and tested here conceptually.
def test_pg_check_health_ok(pg_storage: PostgresMetadataStorage, mock_pool):
    """check_health reports (True, 'postgres') when the probe query succeeds."""
    _, mock_cursor = mock_pool
    mock_cursor.execute.return_value = None  # Simulate successful execution

    is_healthy, backend = pg_storage.check_health()

    assert is_healthy is True
    assert backend == "postgres"
    mock_cursor.execute.assert_called_once_with("SELECT 1;", None)


def test_pg_check_health_fail(pg_storage: PostgresMetadataStorage, mock_pool):
    """check_health reports (False, 'postgres') when the probe query raises psycopg2.Error."""
    _, mock_cursor = mock_pool
    mock_cursor.execute.side_effect = psycopg2.Error("Connection failed")
    # NOTE: this only simulates a failure while executing the probe query. A failure
    # while acquiring a connection would instead be simulated with
    # `pool.getconn.side_effect`, which this test does not cover.

    is_healthy, backend = pg_storage.check_health()

    assert is_healthy is False
    assert backend == "postgres"


def test_pg_get_tensor_descriptors_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_tensor_descriptors_count returns the 'count' column of a COUNT(*) query."""
    _, mock_cursor = mock_pool
    mock_cursor.fetchone.return_value = {'count': 123}

    count = pg_storage.get_tensor_descriptors_count()

    assert count == 123
    mock_cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM tensor_descriptors;", None
    )


def test_pg_get_extended_metadata_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """get_extended_metadata_count maps a metadata type name to its table.

    Unknown type names return 0 without issuing a query.
    """
    _, mock_cursor = mock_pool
    # One fetchone stub serves every count query in this test.
    mock_cursor.fetchone.return_value = {'count': 42}

    count = pg_storage.get_extended_metadata_count("LineageMetadata")
    assert count == 42
    mock_cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM lineage_metadata;", None
    )

    # SemanticMetadata lives in semantic_metadata_entries.
    count_sm = pg_storage.get_extended_metadata_count("SemanticMetadata")
    assert count_sm == 42
    mock_cursor.execute.assert_called_with(
        "SELECT COUNT(*) as count FROM semantic_metadata_entries;", None
    )

    calls_before_unknown = mock_cursor.execute.call_count
    count_unknown = pg_storage.get_extended_metadata_count("UnknownMeta")
    assert count_unknown == 0
    # No table maps to "UnknownMeta", so no new query may have been issued.
    assert mock_cursor.execute.call_count == calls_before_unknown
    assert calls_before_unknown == 2  # Lineage_count and Semantic_count calls