from __future__ import annotations import importlib.util import sys import unittest from pathlib import Path import numpy as np ROOT = Path(__file__).resolve().parents[1] def load_module(name: str, path: Path): spec = importlib.util.spec_from_file_location(name, path) if spec is None or spec.loader is None: raise ImportError(f"Cannot load {path}") module = importlib.util.module_from_spec(spec) sys.modules[name] = module spec.loader.exec_module(module) return module class VisualizationOrientationTests(unittest.TestCase): def test_comparison_uses_the_same_image_transform_for_agent_and_reference(self) -> None: runtime = load_module("visualization_runtime_under_test", ROOT / "tasks" / "_visualization_runtime.py") image = np.arange(6).reshape(2, 3) spec = {"origin": "lower", "extent": [80.0, -80.0, -80.0, 80.0]} np.testing.assert_array_equal(runtime.comparison_image(image, spec), image) def test_uq_task_does_not_force_agent_orientation_to_match_reference(self) -> None: visualization = load_module( "eht_black_hole_uq_visualization", ROOT / "tasks" / "eht_black_hole_UQ" / "evaluation" / "visualization.py", ) output_spec = visualization.PLOT_STYLE["outputs"][0] comparison_spec = visualization.PLOT_STYLE["compare"][0] self.assertNotIn("transform", output_spec) self.assertNotIn("out_transform", comparison_spec) self.assertNotIn("ref_transform", comparison_spec) self.assertEqual(output_spec.get("origin"), comparison_spec.get("origin")) self.assertEqual(output_spec.get("extent"), comparison_spec.get("extent")) if __name__ == "__main__": unittest.main()