# ncview / tests/test_plot.py
# Author: Nipun
# 🌍 TensorView v1.0 - Complete NetCDF/HDF/GRIB viewer
# 433dab5
"""Tests for plotting functionality."""
import pytest
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from tensorview.plot import plot_1d, plot_2d, setup_matplotlib
def create_sample_data(seed=0):
    """Create sample 1D and 2D DataArrays for the plotting tests.

    Args:
        seed: Seed for the random generator that adds noise to the 2D field.
            A fixed seed makes every call (and every test run) see identical
            data, so any failure can be reproduced.

    Returns:
        A ``(da_1d, da_2d)`` tuple:

        * ``da_1d`` -- 100 samples of ``sin(x)`` for ``x`` in [0, 10], with
          ``units='m/s'`` and ``long_name='Sine Wave'``.
        * ``da_2d`` -- a 15 x 20 field on ``(lat, lon)``, built as
          ``sin(lon/5) * cos(lat/5)`` plus Gaussian noise (std 0.1), with
          ``units='degrees_C'`` and ``long_name='Temperature'``.
    """
    # Use a local generator rather than the global np.random state, so the
    # data is deterministic and other code's use of np.random has no effect.
    rng = np.random.default_rng(seed)

    # 1D data
    x = np.linspace(0, 10, 100)
    y = np.sin(x)
    da_1d = xr.DataArray(y, coords={'x': x}, dims=['x'],
                         attrs={'units': 'm/s', 'long_name': 'Sine Wave'})

    # 2D data
    lons = np.linspace(-10, 10, 20)
    lats = np.linspace(-10, 10, 15)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    noise = rng.standard_normal(lat_grid.shape) * 0.1
    temp_data = np.sin(lon_grid / 5) * np.cos(lat_grid / 5) + noise
    da_2d = xr.DataArray(temp_data,
                         coords={'lat': lats, 'lon': lons},
                         dims=['lat', 'lon'],
                         attrs={'units': 'degrees_C', 'long_name': 'Temperature'})
    return da_1d, da_2d
def test_setup_matplotlib():
    """setup_matplotlib() should select the non-interactive Agg backend."""
    setup_matplotlib()
    # Depending on the Matplotlib version, the validated backend name comes
    # back as either 'agg' or 'Agg', so compare without regard to case.
    assert plt.get_backend().lower() == 'agg'
def test_plot_1d():
    """Plot a 1D DataArray and check its axes, line, labels and title.

    The figure is closed in a ``finally`` block. A failing assertion then
    does not leave it open in pyplot's figure registry for later tests.
    """
    da_1d, _ = create_sample_data()
    fig = plot_1d(da_1d)
    assert fig is not None
    try:
        assert len(fig.axes) == 1
        ax = fig.axes[0]
        assert len(ax.lines) == 1
        # The 'x' coordinate has no 'units' attribute, so the expected
        # label has empty parentheses.
        assert ax.get_xlabel() == 'x ()'
        assert 'Sine Wave' in ax.get_title()
    finally:
        plt.close(fig)
def test_plot_2d():
    """Plot a 2D DataArray with kind="image" and kind="contour".

    The image figure is closed in a ``finally`` block. A failing label or
    title assertion then does not leave it open in pyplot's figure registry.
    """
    _, da_2d = create_sample_data()

    # Image plot: axis labels and title should come from the DataArray.
    fig = plot_2d(da_2d, kind="image")
    assert fig is not None
    try:
        assert len(fig.axes) >= 1  # Plot axis + possibly colorbar axis
        ax = fig.axes[0]
        assert ax.get_xlabel() == 'lon ()'
        assert ax.get_ylabel() == 'lat ()'
        assert 'Temperature' in ax.get_title()
    finally:
        plt.close(fig)

    # Contour plot: only check that a figure is produced.
    fig = plot_2d(da_2d, kind="contour")
    assert fig is not None
    plt.close(fig)
def test_plot_styling():
    """Check that styling keyword arguments are accepted and applied.

    The 1D figure is closed in a ``finally`` block. A failing style
    assertion then does not leave it open in pyplot's figure registry.
    """
    da_1d, da_2d = create_sample_data()

    # 1D styling: colour and width should reach the plotted Line2D.
    fig = plot_1d(da_1d, color='red', linewidth=2, grid=False)
    try:
        ax = fig.axes[0]
        assert ax.lines[0].get_color() == 'red'
        assert ax.lines[0].get_linewidth() == 2
    finally:
        plt.close(fig)

    # 2D styling: a colormap and explicit colour limits should be accepted.
    fig = plot_2d(da_2d, cmap='plasma', vmin=-1, vmax=1)
    assert fig is not None
    plt.close(fig)
def test_auto_dimension_detection():
    """plot_2d should produce a figure when no dimensions are named explicitly."""
    sample_2d = create_sample_data()[1]
    # Deliberately pass no dimension arguments; plot_2d must pick them itself.
    figure = plot_2d(sample_2d)
    assert figure is not None
    plt.close(figure)