# ptpd-calibration-studio/tests/unit/test_visualization.py
# Claude
# feat: Add comprehensive curve display and step wedge analysis features
# 2903bd0 unverified
"""
Tests for curve visualization module.
"""
import numpy as np
import pytest
from ptpd_calibration.core.models import CurveData
from ptpd_calibration.core.types import CurveType
from ptpd_calibration.curves.visualization import (
CurveVisualizer,
CurveStatistics,
CurveComparisonResult,
VisualizationConfig,
PlotStyle,
ColorScheme,
)
class TestVisualizationConfig:
    """Checks for VisualizationConfig defaults, overrides and palettes."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        cfg = VisualizationConfig()

        # Figure geometry and resolution.
        assert cfg.figure_width == 10.0
        assert cfg.figure_height == 6.0
        assert cfg.dpi == 100
        # Line styling and decoration toggles.
        assert cfg.line_width == 2.0
        assert cfg.show_grid is True
        assert cfg.show_legend is True
        # Default palette selection.
        assert cfg.color_scheme == ColorScheme.PLATINUM

    def test_custom_config(self):
        """Values passed to the constructor are stored unchanged."""
        cfg = VisualizationConfig(
            figure_width=12.0,
            figure_height=8.0,
            dpi=150,
            color_scheme=ColorScheme.ACCESSIBLE,
            show_grid=False,
        )

        assert cfg.figure_width == 12.0
        assert cfg.figure_height == 8.0
        assert cfg.dpi == 150
        assert cfg.color_scheme == ColorScheme.ACCESSIBLE
        assert cfg.show_grid is False

    def test_get_color_palette_platinum(self):
        """The platinum scheme returns the requested number of '#'-prefixed strings."""
        cfg = VisualizationConfig(color_scheme=ColorScheme.PLATINUM)
        palette = cfg.get_color_palette(5)

        assert len(palette) == 5
        # Every entry must be a string before the prefix check is meaningful.
        for entry in palette:
            assert isinstance(entry, str)
        for entry in palette:
            assert entry.startswith("#")

    def test_get_color_palette_accessible(self):
        """The accessible scheme returns eight colors including its blue."""
        cfg = VisualizationConfig(color_scheme=ColorScheme.ACCESSIBLE)
        palette = cfg.get_color_palette(8)

        assert len(palette) == 8
        # Accessible palette should have high contrast colors
        assert "#0072B2" in palette  # Blue

    def test_get_color_palette_custom(self):
        """Explicit custom colors are handed back as the palette."""
        requested = ["#FF0000", "#00FF00", "#0000FF"]
        cfg = VisualizationConfig(custom_colors=requested)

        palette = cfg.get_color_palette(3)

        assert palette == requested

    def test_get_color_palette_wrapping(self):
        """Asking the monochrome scheme for 20 colors yields exactly 20 entries."""
        cfg = VisualizationConfig(color_scheme=ColorScheme.MONOCHROME)
        palette = cfg.get_color_palette(20)

        assert len(palette) == 20
class TestCurveStatistics:
    """Checks for the CurveStatistics value object."""

    def test_to_dict(self):
        """to_dict() exposes the scalar fields plus input/output range keys."""
        field_values = {
            "name": "Test Curve",
            "num_points": 256,
            "input_min": 0.0,
            "input_max": 1.0,
            "output_min": 0.0,
            "output_max": 1.0,
            "gamma": 1.0,
            "midpoint_value": 0.5,
            "is_monotonic": True,
            "max_slope": 1.2,
            "min_slope": 0.8,
            "average_slope": 1.0,
            "linearity_error": 0.01,
        }
        summary = CurveStatistics(**field_values).to_dict()

        # Scalar fields are carried over under their own names.
        assert summary["name"] == "Test Curve"
        assert summary["num_points"] == 256
        assert summary["gamma"] == 1.0
        assert summary["is_monotonic"] is True
        # The min/max pairs are expected under combined range keys.
        assert "input_range" in summary
        assert "output_range" in summary
class TestCurveComparisonResult:
    """Checks for the CurveComparisonResult value object."""

    def test_to_dict(self):
        """to_dict() reports the curve names under 'curves_compared' plus metrics."""
        comparison = CurveComparisonResult(
            curve_names=["Curve A", "Curve B"],
            max_difference=0.05,
            average_difference=0.02,
            rms_difference=0.03,
            correlation=0.98,
        )

        exported = comparison.to_dict()

        # Note the key rename: curve_names -> "curves_compared".
        assert exported["curves_compared"] == ["Curve A", "Curve B"]
        assert exported["max_difference"] == 0.05
        assert exported["correlation"] == 0.98
class TestCurveVisualizer:
    """Exercise CurveVisualizer: statistics, comparison, plotting and export.

    NOTE(review): the figures returned by the plot_* calls are never closed
    in these tests; if they are matplotlib figures, closing them in a fixture
    teardown may be worthwhile -- confirm against the visualizer.
    """

    # ------------------------------------------------------------------ #
    # Fixtures
    # ------------------------------------------------------------------ #

    @pytest.fixture
    def sample_curve(self):
        """Return a 256-point curve named "Test Curve" with output = input ** 0.9."""
        grid = np.linspace(0, 1, 256)
        return CurveData(
            name="Test Curve",
            input_values=list(grid),
            output_values=list(grid ** 0.9),
            curve_type=CurveType.LINEAR,
        )

    @pytest.fixture
    def linear_curve(self):
        """Return a 100-point identity curve (output equals input)."""
        ramp = list(np.linspace(0, 1, 100))
        return CurveData(
            name="Linear Curve",
            input_values=ramp,
            # Separate list object so inputs and outputs are not aliased.
            output_values=list(ramp),
        )

    @pytest.fixture
    def gamma_curve(self):
        """Return a 100-point curve with output = input ** 2.2."""
        ramp = list(np.linspace(0, 1, 100))
        powered = np.array(ramp) ** 2.2
        return CurveData(
            name="Gamma 2.2 Curve",
            input_values=ramp,
            output_values=list(powered),
        )

    @pytest.fixture
    def visualizer(self):
        """Return a CurveVisualizer built with its default configuration."""
        return CurveVisualizer()

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #

    def test_init_default_config(self):
        """Without arguments the visualizer carries a VisualizationConfig."""
        subject = CurveVisualizer()

        assert subject.config is not None
        assert isinstance(subject.config, VisualizationConfig)

    def test_init_custom_config(self):
        """A config passed positionally is the one the visualizer exposes."""
        custom = VisualizationConfig(dpi=200, show_grid=False)
        subject = CurveVisualizer(custom)

        assert subject.config.dpi == 200
        assert subject.config.show_grid is False

    # ------------------------------------------------------------------ #
    # Statistics
    # ------------------------------------------------------------------ #

    def test_compute_statistics(self, visualizer, sample_curve):
        """Statistics of the sample curve report its name, size and ranges."""
        summary = visualizer.compute_statistics(sample_curve)

        assert summary.name == "Test Curve"
        assert summary.num_points == 256
        assert summary.input_min == pytest.approx(0.0)
        assert summary.input_max == pytest.approx(1.0)
        # Output bounds are checked with a looser absolute tolerance.
        assert summary.output_min == pytest.approx(0.0, abs=0.01)
        assert summary.output_max == pytest.approx(1.0, abs=0.01)
        assert summary.is_monotonic is True

    def test_compute_statistics_linear(self, visualizer, linear_curve):
        """An identity curve has gamma, slope ~1, midpoint ~0.5, tiny linearity error."""
        summary = visualizer.compute_statistics(linear_curve)

        assert summary.gamma == pytest.approx(1.0, abs=0.1)
        assert summary.midpoint_value == pytest.approx(0.5, abs=0.1)
        assert summary.linearity_error < 0.01
        assert summary.average_slope == pytest.approx(1.0, abs=0.1)

    def test_compute_statistics_gamma(self, visualizer, gamma_curve):
        """A power-2.2 curve reports gamma above 1.5 and a midpoint below 0.5."""
        summary = visualizer.compute_statistics(gamma_curve)

        # Gamma 2.2 curve should have gamma close to 2.2
        assert summary.gamma > 1.5
        assert summary.midpoint_value < 0.5  # Darker midtones
        assert summary.is_monotonic is True

    # ------------------------------------------------------------------ #
    # Comparison
    # ------------------------------------------------------------------ #

    def test_compare_curves_same(self, visualizer, linear_curve):
        """Comparing a curve with itself gives zero difference and correlation 1."""
        pair = [linear_curve, linear_curve]
        outcome = visualizer.compare_curves(pair)

        assert outcome.max_difference == pytest.approx(0.0)
        assert outcome.average_difference == pytest.approx(0.0)
        assert outcome.correlation == pytest.approx(1.0)

    def test_compare_curves_different(self, visualizer, linear_curve, gamma_curve):
        """Distinct curves give positive differences and correlation below 1."""
        pair = [linear_curve, gamma_curve]
        outcome = visualizer.compare_curves(pair)

        assert outcome.max_difference > 0
        assert outcome.average_difference > 0
        assert outcome.correlation < 1.0
        assert outcome.difference_curve is not None

    def test_compare_curves_insufficient(self, visualizer, linear_curve):
        """Fewer than two curves is rejected with a ValueError."""
        lonely = [linear_curve]
        with pytest.raises(ValueError, match="At least 2 curves"):
            visualizer.compare_curves(lonely)

    # ------------------------------------------------------------------ #
    # Plotting
    # ------------------------------------------------------------------ #

    def test_plot_single_curve(self, visualizer, sample_curve):
        """Plotting one curve returns an object that has an ``axes`` attribute."""
        figure = visualizer.plot_single_curve(sample_curve)

        assert figure is not None
        # Check that figure was created
        assert hasattr(figure, "axes")

    def test_plot_single_curve_with_options(self, visualizer, sample_curve):
        """Title, style, color and stats options are accepted together."""
        figure = visualizer.plot_single_curve(
            sample_curve,
            title="Custom Title",
            style=PlotStyle.LINE_MARKERS,
            color="#FF0000",
            show_stats=True,
        )

        assert figure is not None

    def test_plot_multiple_curves(self, visualizer, linear_curve, gamma_curve):
        """Two curves can be drawn on one figure with a title."""
        curves = [linear_curve, gamma_curve]
        figure = visualizer.plot_multiple_curves(curves, title="Comparison")

        assert figure is not None

    def test_plot_multiple_curves_with_difference(self, visualizer, linear_curve, gamma_curve):
        """The show_difference option is accepted for a two-curve plot."""
        curves = [linear_curve, gamma_curve]
        figure = visualizer.plot_multiple_curves(curves, show_difference=True)

        assert figure is not None

    def test_plot_multiple_curves_empty(self, visualizer):
        """An empty curve list is rejected with a ValueError."""
        nothing = []
        with pytest.raises(ValueError, match="No curves provided"):
            visualizer.plot_multiple_curves(nothing)

    def test_plot_with_statistics(self, visualizer, linear_curve, gamma_curve):
        """plot_with_statistics returns a figure for two curves and a title."""
        curves = [linear_curve, gamma_curve]
        figure = visualizer.plot_with_statistics(curves, title="Analysis")

        assert figure is not None

    def test_plot_histogram(self, visualizer, sample_curve):
        """plot_histogram returns a figure when asked for 50 bins."""
        figure = visualizer.plot_histogram(sample_curve, bins=50)

        assert figure is not None

    def test_plot_slope_analysis(self, visualizer, sample_curve):
        """plot_slope_analysis returns a figure for the sample curve."""
        figure = visualizer.plot_slope_analysis(sample_curve)

        assert figure is not None

    # ------------------------------------------------------------------ #
    # Export
    # ------------------------------------------------------------------ #

    def test_figure_to_bytes(self, visualizer, sample_curve):
        """Rendering to PNG yields non-empty bytes starting with the PNG signature."""
        figure = visualizer.plot_single_curve(sample_curve)
        payload = visualizer.figure_to_bytes(figure, format="png")

        assert isinstance(payload, bytes)
        assert len(payload) > 0
        # PNG magic bytes
        png_signature = b"\x89PNG\r\n\x1a\n"
        assert payload[:8] == png_signature

    def test_save_figure(self, visualizer, sample_curve, tmp_path):
        """save_figure writes a non-empty file and returns the path it was given."""
        target = tmp_path / "test_curve.png"
        figure = visualizer.plot_single_curve(sample_curve)

        returned = visualizer.save_figure(figure, target)

        assert returned == target
        assert target.exists()
        assert target.stat().st_size > 0
class TestPlotStyles:
    """Smoke tests: each PlotStyle listed below renders a single curve."""

    @pytest.fixture
    def sample_curve(self):
        """Return a 50-point curve named "Test" with output = input ** 0.9."""
        xs = list(np.linspace(0, 1, 50))
        ys = list(np.array(xs) ** 0.9)
        return CurveData(name="Test", input_values=xs, output_values=ys)

    @pytest.fixture
    def visualizer(self):
        """Return a CurveVisualizer with default configuration."""
        return CurveVisualizer()

    @pytest.mark.parametrize(
        "style",
        [
            PlotStyle.LINE,
            PlotStyle.LINE_MARKERS,
            PlotStyle.SCATTER,
            PlotStyle.AREA,
            PlotStyle.STEP,
        ],
    )
    def test_all_plot_styles(self, visualizer, sample_curve, style):
        """plot_single_curve returns a figure for the given style without raising."""
        figure = visualizer.plot_single_curve(sample_curve, style=style)

        assert figure is not None
class TestColorSchemes:
    """Tests that each listed color scheme yields well-formed hex colors."""

    @pytest.mark.parametrize("scheme", [
        ColorScheme.PLATINUM,
        ColorScheme.MONOCHROME,
        ColorScheme.VIBRANT,
        ColorScheme.PASTEL,
        ColorScheme.ACCESSIBLE,
    ])
    def test_all_color_schemes(self, scheme):
        """Test all color schemes provide valid ``#RRGGBB`` colors.

        For every scheme, a palette of five colors is requested and each
        entry must be a 7-character string: ``#`` followed by six
        hexadecimal digits.
        """
        config = VisualizationConfig(color_scheme=scheme)
        colors = config.get_color_palette(5)
        assert len(colors) == 5

        hex_digits = set("0123456789abcdefABCDEF")
        for color in colors:
            assert isinstance(color, str)
            assert color.startswith("#")
            assert len(color) == 7  # #RRGGBB format
            # Prefix and length alone would accept strings such as "#zzzzzz";
            # also require the six payload characters to be hex digits.
            assert set(color[1:]) <= hex_digits, f"{color!r} is not #RRGGBB"