# core/tests/test_postgres_storage.py
# NOTE: upload-page residue ("tensorus's picture / Upload 83 files / edfa748 verified")
# converted to this comment so the module parses.
import pytest
from unittest.mock import MagicMock, patch, call
pytest.importorskip("psycopg2")
import psycopg2
from uuid import uuid4, UUID
from datetime import datetime
import json
from tensorus.metadata.postgres_storage import PostgresMetadataStorage
from tensorus.metadata.schemas import (
TensorDescriptor, SemanticMetadata, DataType, StorageFormat, AccessControl, CompressionInfo,
LineageMetadata, ComputationalMetadata, QualityMetadata, RelationalMetadata, UsageMetadata,
LineageSource, LineageSourceType, ParentTensorLink # For constructing objects
)
from tensorus.metadata.schemas_iodata import TensorusExportData, TensorusExportEntry # For type hints
# --- Mocks ---
@pytest.fixture
def mock_pool():
    """Build a fake psycopg2 connection pool plus the cursor it hands out.

    Returns:
        (pool, cursor): the cursor is exposed so tests can assert on the
        SQL statements the storage layer executes.
    """
    cursor = MagicMock()
    connection = MagicMock()
    # Support the `with conn.cursor() as cur:` pattern used by the storage code.
    connection.cursor.return_value.__enter__.return_value = cursor
    fake_pool = MagicMock(spec=psycopg2.pool.SimpleConnectionPool)
    fake_pool.getconn.return_value = connection
    fake_pool.putconn.return_value = None  # Returning a connection is a no-op here.
    return fake_pool, cursor
@pytest.fixture
def pg_storage(mock_pool):
    """PostgresMetadataStorage instance wired to the mocked connection pool."""
    fake_pool, _ = mock_pool
    # Patch the pool constructor at its source module so construction of the
    # storage object picks up the mock instead of opening a real connection.
    with patch('psycopg2.pool.SimpleConnectionPool', return_value=fake_pool):
        return PostgresMetadataStorage(dsn="postgresql://mockuser:mockpass@mockhost/mockdb")
# --- TensorDescriptor Method Tests (Conceptual SQL Generation) ---
def test_pg_add_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """Adding a TensorDescriptor issues exactly one INSERT with tuple params."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    ts = datetime.utcnow()
    access = AccessControl(read=["user1"], write=["owner"])
    compression = CompressionInfo(algorithm="zstd", level=3)
    descriptor = TensorDescriptor(
        tensor_id=tensor_id,
        dimensionality=2,
        shape=[10, 20],
        data_type=DataType.FLOAT32,
        storage_format=StorageFormat.RAW,
        creation_timestamp=ts,
        last_modified_timestamp=ts,
        owner="test_owner",
        access_control=access,
        byte_size=1600,
        checksum="chk123",
        compression_info=compression,
        tags=["tag1", "tag2"],
        metadata={"key": "value"},
    )

    pg_storage.add_tensor_descriptor(descriptor)

    cursor.execute.assert_called_once()
    sql, params = cursor.execute.call_args[0]
    assert "INSERT INTO tensor_descriptors" in sql
    # Storage passes positional tuple params; Pydantic v2 sub-models are
    # serialized via model_dump_json() (v1 would use .json()).
    assert params == (
        tensor_id, 2, [10, 20], 'float32', 'raw', ts, ts, 'test_owner',
        access.model_dump_json(), 1600, 'chk123', compression.model_dump_json(),
        ['tag1', 'tag2'], json.dumps({"key": "value"}),
    )
def test_pg_get_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """Fetching by id runs a parameterized SELECT and rehydrates the model."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    ts = datetime.utcnow()
    # Simulated DB row as a dict (RealDictCursor-style result).
    cursor.fetchone.return_value = {
        'tensor_id': tensor_id,
        'dimensionality': 1,
        'shape': [5],
        'data_type': 'int32',
        'storage_format': 'numpy_npz',
        'creation_timestamp': ts,
        'last_modified_timestamp': ts,
        'owner': 'fetch_user',
        'access_control': {"read": ["public"]},
        'byte_size': 20,
        'checksum': None,
        'compression_info': None,
        'tags': ['fetched'],
        'metadata': {'source': 'db'},
    }

    result = pg_storage.get_tensor_descriptor(tensor_id)

    cursor.execute.assert_called_once()
    sql, params = cursor.execute.call_args[0]
    assert sql == "SELECT * FROM tensor_descriptors WHERE tensor_id = %s;"
    assert params == (tensor_id,)
    assert result is not None
    assert result.tensor_id == tensor_id
    assert result.owner == 'fetch_user'
    assert result.data_type == DataType.INT32
    assert result.access_control.read == ["public"]
    assert result.tags == ['fetched']
def test_pg_delete_tensor_descriptor(pg_storage: PostgresMetadataStorage, mock_pool):
    """Delete returns True when the DELETE statement affects a row."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    cursor.rowcount = 1  # Pretend exactly one row was removed.
    assert pg_storage.delete_tensor_descriptor(tensor_id) is True
    cursor.execute.assert_called_once_with(
        "DELETE FROM tensor_descriptors WHERE tensor_id = %s;", (tensor_id,)
    )
# --- SemanticMetadata Method Tests ---
def test_pg_add_semantic_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """Semantic metadata insert first verifies the parent descriptor exists."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    ts = datetime.utcnow()
    # Minimal descriptor row so the parent-existence lookup succeeds.
    cursor.fetchone.return_value = {
        'tensor_id': tensor_id, 'dimensionality': 1, 'shape': [1],
        'data_type': 'float32', 'storage_format': 'raw',
        'creation_timestamp': ts, 'last_modified_timestamp': ts,
        'owner': 'owner', 'access_control': {}, 'byte_size': 4,
        'checksum': None, 'compression_info': None, 'tags': [], 'metadata': {},
    }

    semantic = SemanticMetadata(tensor_id=tensor_id, name="purpose", description="for science")
    pg_storage.add_semantic_metadata(semantic)

    # Call 1: get_tensor_descriptor lookup; call 2: the semantic INSERT itself.
    assert cursor.execute.call_count == 2
    insert_sql, insert_params = cursor.execute.call_args_list[1][0]
    assert "INSERT INTO semantic_metadata_entries" in insert_sql
    assert insert_params == (tensor_id, "purpose", "for science")
# --- JSONB Extended Metadata Tests (Example with LineageMetadata) ---
def test_pg_add_jsonb_lineage_metadata(pg_storage: PostgresMetadataStorage, mock_pool):
    """Lineage metadata is stored as a JSONB payload keyed by tensor_id."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    ts = datetime.utcnow()
    # Minimal descriptor row so the _add_jsonb_metadata existence check passes.
    cursor.fetchone.return_value = {
        'tensor_id': tensor_id, 'dimensionality': 1, 'shape': [1],
        'data_type': 'float32', 'storage_format': 'raw',
        'creation_timestamp': ts, 'last_modified_timestamp': ts,
        'owner': 'owner', 'access_control': {}, 'byte_size': 4,
        'checksum': None, 'compression_info': None, 'tags': [], 'metadata': {},
    }

    lineage = LineageMetadata(tensor_id=tensor_id, version="v1.pg")
    pg_storage.add_lineage_metadata(lineage)

    # Call 1: get_tensor_descriptor; call 2: the lineage INSERT.
    assert cursor.execute.call_count == 2
    insert_sql, insert_params = cursor.execute.call_args_list[1][0]
    assert "INSERT INTO lineage_metadata (tensor_id, data)" in insert_sql
    # Pydantic v2 serialization: model_dump_json().
    assert insert_params == {"tensor_id": tensor_id, "data": lineage.model_dump_json()}
# --- Test list_tensor_descriptors with filters (SQL construction) ---
def test_pg_list_tensor_descriptors_with_filters(pg_storage: PostgresMetadataStorage, mock_pool):
    """owner + lineage_version filters produce a JOINed query with named params."""
    _, cursor = mock_pool

    pg_storage.list_tensor_descriptors(owner="filter_owner", lineage_version="vFilter")

    cursor.execute.assert_called_once()
    sql, params = cursor.execute.call_args[0]
    for fragment in (
        "FROM tensor_descriptors td",
        "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id",
        "td.owner = %(owner)s",
        "lm.data->>'version' = %(lineage_version)s",
    ):
        assert fragment in sql
    assert params == {"owner": "filter_owner", "lineage_version": "vFilter"}
# --- Test Lineage Parent/Child SQL Construction ---
def test_pg_get_parent_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Parent lookup unpacks the lineage JSONB 'parent_tensors' array."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    cursor.fetchall.return_value = [
        {'parent_id': str(uuid4())},
        {'parent_id': str(uuid4())},
    ]  # Simulated DB response rows.

    pg_storage.get_parent_tensor_ids(tensor_id)

    cursor.execute.assert_called_once()
    sql, params = cursor.execute.call_args[0]
    assert "jsonb_array_elements(lm.data->'parent_tensors')" in sql
    assert "WHERE lm.tensor_id = %(tensor_id)s" in sql
    assert params == {"tensor_id": tensor_id}
def test_pg_get_child_tensor_ids_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Child lookup matches the target id inside each row's parent_tensors array."""
    _, cursor = mock_pool
    tensor_id = uuid4()
    cursor.fetchall.return_value = [{'tensor_id': uuid4()}, {'tensor_id': uuid4()}]

    pg_storage.get_child_tensor_ids(tensor_id)

    cursor.execute.assert_called_once()
    sql, params = cursor.execute.call_args[0]
    assert "jsonb_array_elements(lm.data->'parent_tensors') AS parent" in sql
    assert "parent.value->>'tensor_id' = %(target_parent_id)s" in sql
    # JSONB text extraction compares strings, so the UUID is stringified.
    assert params == {"target_parent_id": str(tensor_id)}
# --- Test Search SQL Construction (Simplified) ---
def test_pg_search_tensor_descriptors_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Text search ILIKEs across descriptor, semantic (sm) and lineage (lm) fields."""
    _, cursor = mock_pool

    pg_storage.search_tensor_descriptors(
        "test_query", ["owner", "semantic.description", "lineage.version"]
    )

    cursor.execute.assert_called_once()
    sql, params = cursor.execute.call_args[0]
    for fragment in (
        "owner ILIKE %(text_query)s",
        "description ILIKE %(text_query)s",          # semantic_metadata_entries alias sm
        "lm.data->>'version' ILIKE %(text_query)s",  # lineage_metadata alias lm
        "LEFT JOIN semantic_metadata_entries sm ON td.tensor_id = sm.tensor_id",
        "LEFT JOIN lineage_metadata lm ON td.tensor_id = lm.tensor_id",
    ):
        assert fragment in sql
    assert params == {"text_query": "%test_query%"}
# --- Test Aggregation SQL Construction (Simplified for count) ---
def test_pg_aggregate_tensor_descriptors_count_sql(pg_storage: PostgresMetadataStorage, mock_pool):
    """Aggregation by count is currently rejected with ValueError."""
    _, cursor = mock_pool
    cursor.fetchall.return_value = [{'data_type': 'float32', 'count': 5}]
    with pytest.raises(ValueError):
        pg_storage.aggregate_tensor_descriptors("data_type", "count")
# Note: Full implementation of export/import for Postgres is complex and marked NotImplemented.
# Tests for those would require significant mocking or a live DB and are out of scope here.
# Similarly, health check and count methods are simple SQL and tested here conceptually.
def test_pg_check_health_ok(pg_storage: PostgresMetadataStorage, mock_pool):
    """A successful 'SELECT 1;' probe reports the backend as healthy."""
    _, cursor = mock_pool
    cursor.execute.return_value = None  # No exception => healthy path.

    healthy, backend_name = pg_storage.check_health()

    assert healthy is True
    assert backend_name == "postgres"
    cursor.execute.assert_called_once_with("SELECT 1;", None)
def test_pg_check_health_fail(pg_storage: PostgresMetadataStorage, mock_pool):
    """A psycopg2 error during the probe marks the backend unhealthy."""
    pool, cursor = mock_pool
    cursor.execute.side_effect = psycopg2.Error("Connection failed")
    # If the failure were in the pool itself, we would also need:
    # pool.getconn.side_effect = psycopg2.Error("Pool failed")

    healthy, backend_name = pg_storage.check_health()

    assert healthy is False
    assert backend_name == "postgres"
def test_pg_get_tensor_descriptors_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """Descriptor count comes straight from a COUNT(*) query."""
    _, cursor = mock_pool
    cursor.fetchone.return_value = {'count': 123}
    assert pg_storage.get_tensor_descriptors_count() == 123
    cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM tensor_descriptors;", None
    )
def test_pg_get_extended_metadata_count(pg_storage: PostgresMetadataStorage, mock_pool):
    """Each known metadata kind maps to its table; unknown kinds return 0 without SQL."""
    _, cursor = mock_pool
    cursor.fetchone.return_value = {'count': 42}

    assert pg_storage.get_extended_metadata_count("LineageMetadata") == 42
    cursor.execute.assert_called_once_with(
        "SELECT COUNT(*) as count FROM lineage_metadata;", None
    )

    # SemanticMetadata lives in semantic_metadata_entries; reuses the same
    # mocked fetchone, so the count is still 42.
    assert pg_storage.get_extended_metadata_count("SemanticMetadata") == 42
    cursor.execute.assert_called_with(
        "SELECT COUNT(*) as count FROM semantic_metadata_entries;", None
    )

    # Unknown kinds short-circuit (warning + 0) without issuing another query:
    # execute call count stays at the two calls made above.
    assert pg_storage.get_extended_metadata_count("UnknownMeta") == 0
    assert cursor.execute.call_count == 2