|
|
"""Tests for plotting functionality.""" |
|
|
|
|
|
import pytest |
|
|
import numpy as np |
|
|
import xarray as xr |
|
|
import matplotlib.pyplot as plt |
|
|
from tensorview.plot import plot_1d, plot_2d, setup_matplotlib |
|
|
|
|
|
|
|
|
def create_sample_data(seed=0):
    """Create deterministic sample DataArrays for the plotting tests.

    Args:
        seed: Seed for the random generator that adds noise to the 2D field.
            A fixed default keeps the fixture reproducible between test runs.

    Returns:
        A ``(da_1d, da_2d)`` tuple where:

        * ``da_1d`` is ``sin(x)`` sampled at 100 points of ``x`` in [0, 10],
          with dim ``'x'`` and attrs ``units='m/s'``,
          ``long_name='Sine Wave'``.
        * ``da_2d`` has shape ``(15, 20)`` over dims ``('lat', 'lon')``,
          both coordinates spanning [-10, 10]. It holds a smooth sin/cos
          field plus small Gaussian noise (std 0.1), with attrs
          ``units='degrees_C'``, ``long_name='Temperature'``.
    """
    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    da_1d = xr.DataArray(
        y,
        coords={'x': x},
        dims=['x'],
        attrs={'units': 'm/s', 'long_name': 'Sine Wave'},
    )

    lons = np.linspace(-10, 10, 20)
    lats = np.linspace(-10, 10, 15)
    lon_grid, lat_grid = np.meshgrid(lons, lats)

    # Use a local seeded generator instead of the global RNG so the fixture
    # is identical on every run and does not disturb global random state.
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal(lat_grid.shape) * 0.1
    temp_data = np.sin(lon_grid / 5) * np.cos(lat_grid / 5) + noise

    da_2d = xr.DataArray(
        temp_data,
        coords={'lat': lats, 'lon': lons},
        dims=['lat', 'lon'],
        attrs={'units': 'degrees_C', 'long_name': 'Temperature'},
    )

    return da_1d, da_2d
|
|
|
|
|
|
|
|
def test_setup_matplotlib():
    """Check that setup_matplotlib selects the non-interactive Agg backend."""
    setup_matplotlib()

    # matplotlib may report the canonical lowercase name ('agg') rather than
    # the spelling passed in, so compare case-insensitively.
    assert plt.get_backend().lower() == 'agg'
|
|
|
|
|
|
|
|
def test_plot_1d():
    """Check that plot_1d draws one labelled line for a 1D DataArray."""
    da_1d, _ = create_sample_data()

    fig = plot_1d(da_1d)
    try:
        assert fig is not None
        assert len(fig.axes) == 1

        ax = fig.axes[0]
        assert len(ax.lines) == 1
        # The 'x' coordinate has no 'units' attribute, so the expected
        # axis label carries an empty unit suffix.
        assert ax.get_xlabel() == 'x ()'
        # The title is expected to include the DataArray's long_name.
        assert 'Sine Wave' in ax.get_title()
    finally:
        # Close even when an assertion fails so figures don't pile up
        # across the test session.
        if fig is not None:
            plt.close(fig)
|
|
|
|
|
|
|
|
def test_plot_2d():
    """Check plot_2d with both the "image" and the "contour" kind."""
    _, da_2d = create_sample_data()

    # Image plot: check axis labels and title.
    fig = plot_2d(da_2d, kind="image")
    try:
        assert fig is not None
        # ">= 1": the figure may hold extra axes (presumably a colorbar;
        # not visible from this file).
        assert len(fig.axes) >= 1

        ax = fig.axes[0]
        # Neither coordinate has a 'units' attribute, hence the empty "()".
        assert ax.get_xlabel() == 'lon ()'
        assert ax.get_ylabel() == 'lat ()'
        assert 'Temperature' in ax.get_title()
    finally:
        # Always release the figure, even if an assertion above failed.
        if fig is not None:
            plt.close(fig)

    # Contour plot: only check that a figure is produced.
    fig = plot_2d(da_2d, kind="contour")
    try:
        assert fig is not None
    finally:
        if fig is not None:
            plt.close(fig)
|
|
|
|
|
|
|
|
def test_plot_styling():
    """Check that styling keyword arguments are accepted and applied."""
    da_1d, da_2d = create_sample_data()

    # 1D: color and linewidth should end up on the drawn Line2D.
    fig = plot_1d(da_1d, color='red', linewidth=2, grid=False)
    try:
        ax = fig.axes[0]
        line = ax.lines[0]
        assert line.get_color() == 'red'
        assert line.get_linewidth() == 2
    finally:
        # Close even when an assertion fails to avoid leaking figures.
        if fig is not None:
            plt.close(fig)

    # 2D: colormap and colour limits should be accepted without error.
    fig = plot_2d(da_2d, cmap='plasma', vmin=-1, vmax=1)
    try:
        assert fig is not None
    finally:
        if fig is not None:
            plt.close(fig)
|
|
|
|
|
|
|
|
def test_auto_dimension_detection():
    """Call plot_2d with only the DataArray, relying on its default dims/kind."""
    sample_2d = create_sample_data()[1]

    # No kind or dimension arguments: plot_2d has to choose them itself.
    figure = plot_2d(sample_2d)
    assert figure is not None
    plt.close(figure)