"""Tests for the plotting helpers in ``tensorview.plot``."""

import pytest
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

from tensorview.plot import plot_1d, plot_2d, setup_matplotlib


@pytest.fixture(autouse=True)
def _close_all_figures():
    """Close every open figure after each test.

    The tests close their figures explicitly, but those calls are skipped
    when an earlier assertion fails; this teardown guarantees no figure
    leaks into later tests.
    """
    yield
    plt.close("all")


def create_sample_data():
    """Create deterministic sample data for testing.

    Returns:
        A ``(da_1d, da_2d)`` tuple:

        * ``da_1d`` -- ``sin(x)`` sampled at 100 points of ``x`` in [0, 10],
          with ``units='m/s'`` and ``long_name='Sine Wave'``; the ``x``
          coordinate itself carries no attributes.
        * ``da_2d`` -- a 15 x 20 array on dims ``('lat', 'lon')`` holding a
          smooth sin/cos pattern plus small noise, with
          ``units='degrees_C'`` and ``long_name='Temperature'``.
    """
    # 1D data
    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    da_1d = xr.DataArray(
        y,
        coords={'x': x},
        dims=['x'],
        attrs={'units': 'm/s', 'long_name': 'Sine Wave'},
    )

    # 2D data: meshgrid(lons, lats) yields arrays of shape (len(lats), len(lons)),
    # matching the ('lat', 'lon') dim order below.
    lons = np.linspace(-10, 10, 20)
    lats = np.linspace(-10, 10, 15)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    # Seeded generator so the noise -- and therefore every test -- is
    # reproducible from run to run.
    rng = np.random.default_rng(0)
    noise = rng.standard_normal(lat_grid.shape) * 0.1
    temp_data = np.sin(lon_grid / 5) * np.cos(lat_grid / 5) + noise
    da_2d = xr.DataArray(
        temp_data,
        coords={'lat': lats, 'lon': lons},
        dims=['lat', 'lon'],
        attrs={'units': 'degrees_C', 'long_name': 'Temperature'},
    )

    return da_1d, da_2d


def test_setup_matplotlib():
    """setup_matplotlib() should select the non-interactive Agg backend."""
    setup_matplotlib()
    # Matplotlib versions differ in how they report this backend's name
    # ('Agg' vs the normalised 'agg'), so compare case-insensitively.
    assert plt.get_backend().lower() == 'agg'


def test_plot_1d():
    """plot_1d() draws one line on one axis with labelled x-axis and title."""
    da_1d, _ = create_sample_data()

    fig = plot_1d(da_1d)
    assert fig is not None
    assert len(fig.axes) == 1

    ax = fig.axes[0]
    assert len(ax.lines) == 1
    # Expected label format is '<name> (<units>)'; the 'x' coordinate has
    # no 'units' attribute, hence the empty parentheses.
    assert ax.get_xlabel() == 'x ()'
    assert 'Sine Wave' in ax.get_title()

    plt.close(fig)


def test_plot_2d():
    """plot_2d() supports both 'image' and 'contour' kinds."""
    _, da_2d = create_sample_data()

    # Test image plot
    fig = plot_2d(da_2d, kind="image")
    assert fig is not None
    assert len(fig.axes) >= 1  # Plot axis + possibly colorbar axis

    ax = fig.axes[0]
    assert ax.get_xlabel() == 'lon ()'
    assert ax.get_ylabel() == 'lat ()'
    assert 'Temperature' in ax.get_title()

    plt.close(fig)

    # Test contour plot
    fig = plot_2d(da_2d, kind="contour")
    assert fig is not None
    plt.close(fig)


def test_plot_styling():
    """Styling keyword arguments are forwarded to the underlying artists."""
    da_1d, da_2d = create_sample_data()

    # Test 1D styling
    fig = plot_1d(da_1d, color='red', linewidth=2, grid=False)
    ax = fig.axes[0]
    assert ax.lines[0].get_color() == 'red'
    assert ax.lines[0].get_linewidth() == 2
    plt.close(fig)

    # Test 2D styling
    fig = plot_2d(da_2d, cmap='plasma', vmin=-1, vmax=1)
    assert fig is not None
    plt.close(fig)


def test_auto_dimension_detection():
    """plot_2d() works without the caller naming the plot dimensions."""
    _, da_2d = create_sample_data()

    # Should work without specifying dimensions
    fig = plot_2d(da_2d)
    assert fig is not None
    plt.close(fig)