Spaces:
Running
Running
| import pytest | |
| from fastapi.testclient import TestClient | |
| from uuid import uuid4, UUID | |
| from datetime import datetime | |
| import json | |
| from tensorus.api import app | |
| from tensorus.metadata.storage import InMemoryStorage # To directly interact for setup/verification | |
| from tensorus.metadata.storage_abc import MetadataStorage | |
| from tensorus.metadata.schemas import TensorDescriptor, SemanticMetadata, LineageMetadata, DataType | |
| from tensorus.metadata.schemas_iodata import TensorusExportData, TensorusExportEntry | |
| from tensorus.api.dependencies import get_storage_instance # To override if needed, or use global | |
| from tensorus.config import settings as global_settings # For API key settings | |
| from tensorus.metadata import storage_instance as global_app_storage_instance # The actual global instance | |
| # --- Fixtures --- | |
@pytest.fixture
def client_with_clean_storage(monkeypatch):
    """
    Provides a TestClient with a fresh InMemoryStorage instance for each test.

    This ensures test isolation for I/O operations. It works by ensuring the
    global `storage_instance` (used by the `get_storage_instance` dependency)
    is an InMemoryStorage and is cleared before and after each test.

    Yields:
        TestClient: a client bound to the FastAPI app under test.
    """
    # Ensure global settings point to in_memory for these tests.
    monkeypatch.setattr(global_settings, "STORAGE_BACKEND", "in_memory")
    # Assert that the globally configured instance is indeed InMemoryStorage.
    # (This will be true if tensorus.metadata was imported after settings were
    # patched, or if the default was already in_memory and not changed by other
    # tests' env vars. For robust testing, one might re-import tensorus.metadata
    # or use app.dependency_overrides.)
    if not isinstance(global_app_storage_instance, InMemoryStorage):
        # This indicates a test setup issue or interference if another test
        # changed the global instance type.
        pytest.skip("Skipping I/O tests: Requires InMemoryStorage to be the active global backend for clearing.")
    global_app_storage_instance.clear_all_data()  # Start fresh
    with TestClient(app) as c:
        yield c
    global_app_storage_instance.clear_all_data()  # Clean up after test
@pytest.fixture
def sample_td_1(client_with_clean_storage: TestClient) -> TensorDescriptor:
    """Create and persist a minimal 1-D FLOAT32 descriptor for export tests."""
    td = TensorDescriptor(
        tensor_id=uuid4(), dimensionality=1, shape=[10], data_type=DataType.FLOAT32,
        owner="io_test_user", byte_size=40, tags=["export_test"]
    )
    # Add directly to the global_app_storage_instance that the TestClient's app is using.
    global_app_storage_instance.add_tensor_descriptor(td)
    return td
@pytest.fixture
def sample_td_2(client_with_clean_storage: TestClient) -> TensorDescriptor:
    """Create and persist a 2-D INT64 descriptor with semantic and lineage metadata."""
    td_id = uuid4()
    td = TensorDescriptor(
        tensor_id=td_id, dimensionality=2, shape=[3,3], data_type=DataType.INT64,
        owner="io_test_user2", byte_size=72
    )
    global_app_storage_instance.add_tensor_descriptor(td)
    # Attach extra metadata so export tests can verify it round-trips.
    global_app_storage_instance.add_semantic_metadata(SemanticMetadata(tensor_id=td_id, name="purpose", description="testing import/export"))
    global_app_storage_instance.add_lineage_metadata(LineageMetadata(tensor_id=td_id, version="v1.0-exp"))
    return td
| # --- /tensors/export Endpoint Tests --- | |
def test_export_all_data(client_with_clean_storage: TestClient, sample_td_1: TensorDescriptor, sample_td_2: TensorDescriptor):
    """Exporting without filters returns every stored tensor with its metadata."""
    resp = client_with_clean_storage.get("/tensors/export")
    assert resp.status_code == 200
    # The endpoint serves the export as a downloadable file.
    assert "attachment; filename=" in resp.headers["content-disposition"]

    payload = resp.json()  # FastAPI TestClient .json() handles parsing
    assert payload["export_format_version"] == "1.0"  # Using the renamed field

    entries = payload["entries"]
    assert len(entries) == 2
    exported_ids = {e["tensor_descriptor"]["tensor_id"] for e in entries}
    assert str(sample_td_1.tensor_id) in exported_ids
    assert str(sample_td_2.tensor_id) in exported_ids

    # Spot-check the second tensor's full entry.
    td2_entry = next(e for e in entries if e["tensor_descriptor"]["tensor_id"] == str(sample_td_2.tensor_id))
    assert td2_entry["tensor_descriptor"]["owner"] == "io_test_user2"
    semantic = td2_entry["semantic_metadata"]
    assert len(semantic) == 1
    assert semantic[0]["name"] == "purpose"
    assert td2_entry["lineage_metadata"]["version"] == "v1.0-exp"
    assert td2_entry["computational_metadata"] is None
def test_export_selected_tensor_ids(client_with_clean_storage: TestClient, sample_td_1: TensorDescriptor, sample_td_2: TensorDescriptor):
    """The tensor_ids query parameter restricts the export to the listed tensors."""
    single = client_with_clean_storage.get(f"/tensors/export?tensor_ids={sample_td_1.tensor_id}")
    assert single.status_code == 200
    entries = single.json()["entries"]
    assert len(entries) == 1
    assert entries[0]["tensor_descriptor"]["tensor_id"] == str(sample_td_1.tensor_id)

    # A comma-separated list selects multiple tensors.
    both = f"{sample_td_1.tensor_id},{sample_td_2.tensor_id}"
    multi = client_with_clean_storage.get(f"/tensors/export?tensor_ids={both}")
    assert multi.status_code == 200
    assert len(multi.json()["entries"]) == 2
def test_export_non_existent_tensor_id(client_with_clean_storage: TestClient, sample_td_1: TensorDescriptor):
    """Unknown IDs are silently skipped; only existing tensors appear in the export."""
    missing = uuid4()
    resp = client_with_clean_storage.get(f"/tensors/export?tensor_ids={missing},{sample_td_1.tensor_id}")
    assert resp.status_code == 200
    entries = resp.json()["entries"]
    assert len(entries) == 1
    assert entries[0]["tensor_descriptor"]["tensor_id"] == str(sample_td_1.tensor_id)
def test_export_invalid_uuid_format(client_with_clean_storage: TestClient):
    """Malformed UUIDs in the query string yield a 400 with an explanatory detail."""
    resp = client_with_clean_storage.get("/tensors/export?tensor_ids=not-a-uuid,another-bad-uuid")
    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert "Invalid UUID format" in detail
| # --- /tensors/import Endpoint Tests --- | |
# API key that the import endpoint's auth dependency will accept in these tests.
API_KEY = "test_io_key_for_real" # Different from security tests to avoid clash if run together
| # Apply to all tests in this module | |
@pytest.fixture(scope="module", autouse=True)
def setup_io_api_keys_module():
    """Set API key settings for all I/O API tests in this module.

    Uses a manually managed ``pytest.MonkeyPatch`` (rather than the
    function-scoped ``monkeypatch`` fixture) so the patches can span the
    whole module and be undone once on teardown.
    """
    mp = pytest.MonkeyPatch()
    mp.setattr(global_settings, "VALID_API_KEYS", [API_KEY])
    mp.setattr(global_settings, "API_KEY_HEADER_NAME", "X-API-KEY")
    # The security module captured the header name at import time, so patch
    # the live APIKeyHeader object as well.
    from tensorus.api.security import api_key_header_auth as global_api_key_header_auth
    mp.setattr(global_api_key_header_auth, "name", "X-API-KEY")
    yield
    mp.undo()
def test_import_data_skip_strategy(client_with_clean_storage: TestClient, sample_td_1: TensorDescriptor):
    """With conflict_strategy=skip, existing tensors are untouched and new ones are added."""
    fresh_id = uuid4()
    fresh_td = TensorDescriptor(
        tensor_id=fresh_id, dimensionality=1, shape=[5], data_type=DataType.INT16,
        owner="importer", byte_size=10,
    )
    payload = TensorusExportData(entries=[
        TensorusExportEntry(tensor_descriptor=sample_td_1),  # conflicts with stored descriptor
        TensorusExportEntry(tensor_descriptor=fresh_td),     # brand new
    ])

    resp = client_with_clean_storage.post(
        "/tensors/import?conflict_strategy=skip",
        json=payload.model_dump(mode="json"),
        headers={"X-API-KEY": API_KEY},
    )
    assert resp.status_code == 200
    counts = resp.json()
    assert (counts["imported"], counts["skipped"], counts["overwritten"], counts["failed"]) == (1, 1, 0, 0)

    # The new tensor is retrievable; the conflicting one kept its original owner.
    assert client_with_clean_storage.get(f"/tensor_descriptors/{fresh_id}").status_code == 200
    existing = client_with_clean_storage.get(f"/tensor_descriptors/{sample_td_1.tensor_id}")
    assert existing.status_code == 200
    assert existing.json()["owner"] == sample_td_1.owner
def test_import_data_overwrite_strategy(client_with_clean_storage: TestClient, sample_td_1: TensorDescriptor):
    """With conflict_strategy=overwrite, an existing descriptor is replaced by the import."""
    modified = sample_td_1.model_copy(deep=True)  # Pydantic v2 deep copy
    modified.owner = "overwritten_owner"
    modified.tags = ["overwritten_tag"]
    payload = TensorusExportData(entries=[TensorusExportEntry(tensor_descriptor=modified)])

    resp = client_with_clean_storage.post(
        "/tensors/import?conflict_strategy=overwrite",
        json=payload.model_dump(mode="json"),
        headers={"X-API-KEY": API_KEY},
    )
    assert resp.status_code == 200

    # InMemoryStorage.import_data counts replacements under "overwritten";
    # "imported" is reserved for truly new tensors.
    counts = resp.json()
    assert (counts["imported"], counts["skipped"], counts["overwritten"], counts["failed"]) == (0, 0, 1, 0)

    stored = client_with_clean_storage.get(f"/tensor_descriptors/{sample_td_1.tensor_id}")
    assert stored.status_code == 200
    body = stored.json()
    assert body["owner"] == "overwritten_owner"
    assert "overwritten_tag" in body["tags"]
def test_import_data_with_all_metadata_types(client_with_clean_storage: TestClient):
    """An entry carrying descriptor + semantic + lineage metadata imports all of them."""
    tid = uuid4()
    entry = TensorusExportEntry(
        tensor_descriptor=TensorDescriptor(
            tensor_id=tid, dimensionality=1, shape=[1], data_type=DataType.UINT8,
            owner="full_import", byte_size=1,
        ),
        semantic_metadata=[SemanticMetadata(tensor_id=tid, name="sm_name", description="sm_desc")],
        lineage_metadata=LineageMetadata(tensor_id=tid, version="vImport"),
    )
    payload = TensorusExportData(entries=[entry])

    resp = client_with_clean_storage.post(
        "/tensors/import",
        json=payload.model_dump(mode="json"),
        headers={"X-API-KEY": API_KEY},
    )
    assert resp.status_code == 200
    assert resp.json()["imported"] == 1

    # Descriptor, semantic, and lineage metadata must all be queryable afterwards.
    assert client_with_clean_storage.get(f"/tensor_descriptors/{tid}").status_code == 200

    semantic = client_with_clean_storage.get(f"/tensor_descriptors/{tid}/semantic")
    assert semantic.status_code == 200
    records = semantic.json()
    assert len(records) == 1
    assert records[0]["name"] == "sm_name"

    lineage = client_with_clean_storage.get(f"/tensor_descriptors/{tid}/lineage")
    assert lineage.status_code == 200
    assert lineage.json()["version"] == "vImport"
def test_import_data_invalid_payload(client_with_clean_storage: TestClient):
    """A body that does not match the TensorusExportData schema is rejected with 422."""
    resp = client_with_clean_storage.post(
        "/tensors/import",
        json={"foo": "bar"},
        headers={"X-API-KEY": API_KEY},
    )
    assert resp.status_code == 422
def test_import_data_invalid_conflict_strategy(client_with_clean_storage: TestClient):
    """An unrecognized conflict_strategy query value is rejected with 422."""
    tid = uuid4()
    descriptor = TensorDescriptor(
        tensor_id=tid, dimensionality=0, shape=[], data_type=DataType.BOOLEAN,
        owner="x", byte_size=0,
    )
    payload = TensorusExportData(entries=[TensorusExportEntry(tensor_descriptor=descriptor)])
    resp = client_with_clean_storage.post(
        "/tensors/import?conflict_strategy=delete_all",
        json=payload.model_dump(mode="json"),
        headers={"X-API-KEY": API_KEY},
    )
    assert resp.status_code == 422
def test_import_data_postgres_not_implemented(client_with_clean_storage: TestClient, monkeypatch):
    """A backend whose import_data raises NotImplementedError surfaces as HTTP 501."""
    from unittest.mock import MagicMock

    # Stand in for a Postgres-backed storage that lacks import support.
    fake_storage = MagicMock(spec=MetadataStorage)
    fake_storage.import_data.side_effect = NotImplementedError("Postgres import not done.")

    # Override the dependency on the app instance used by the TestClient,
    # remembering any pre-existing override so it can be restored afterwards.
    saved_override = app.dependency_overrides.get(get_storage_instance)
    app.dependency_overrides[get_storage_instance] = lambda: fake_storage
    try:
        descriptor = TensorDescriptor(
            tensor_id=uuid4(),
            dimensionality=0,
            shape=[],
            data_type=DataType.BOOLEAN,
            owner="x",
            byte_size=0,
        )
        payload = TensorusExportData(entries=[TensorusExportEntry(tensor_descriptor=descriptor)])
        resp = client_with_clean_storage.post(
            "/tensors/import",
            json=payload.model_dump(mode="json"),
            headers={"X-API-KEY": API_KEY},
        )
        assert resp.status_code == 501
        assert "Import functionality is not implemented" in resp.json()["detail"]
    finally:
        if saved_override is None:
            del app.dependency_overrides[get_storage_instance]
        else:
            app.dependency_overrides[get_storage_instance] = saved_override