File size: 2,636 Bytes
433dab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Tests for plotting functionality."""

import pytest
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from tensorview.plot import plot_1d, plot_2d, setup_matplotlib


def create_sample_data(seed=0):
    """Create reproducible 1D and 2D sample DataArrays for the plotting tests.

    Parameters
    ----------
    seed : int, optional
        Seed for the Gaussian noise added to the 2D field. A fixed default
        keeps the fixture identical across runs, so a failing test can be
        reproduced. (Drawing from the unseeded global NumPy RNG gave
        different data on every run.)

    Returns
    -------
    tuple
        ``(da_1d, da_2d)``:

        * ``da_1d``: 100 samples of ``sin(x)`` for ``x`` in [0, 10], with
          ``units='m/s'`` and ``long_name='Sine Wave'``.
        * ``da_2d``: a 15 x 20 field on dims ``('lat', 'lon')``, with
          ``units='degrees_C'`` and ``long_name='Temperature'``.
    """
    rng = np.random.default_rng(seed)

    # 1D data: a sine wave on a uniform grid.
    x = np.linspace(0, 10, 100)
    da_1d = xr.DataArray(
        np.sin(x),
        coords={'x': x},
        dims=['x'],
        attrs={'units': 'm/s', 'long_name': 'Sine Wave'},
    )

    # 2D data: a smooth sin/cos pattern plus small noise (std 0.1).
    lons = np.linspace(-10, 10, 20)
    lats = np.linspace(-10, 10, 15)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    noise = rng.standard_normal(lat_grid.shape) * 0.1
    temp_data = np.sin(lon_grid / 5) * np.cos(lat_grid / 5) + noise

    da_2d = xr.DataArray(
        temp_data,
        coords={'lat': lats, 'lon': lons},
        dims=['lat', 'lon'],
        attrs={'units': 'degrees_C', 'long_name': 'Temperature'},
    )

    return da_1d, da_2d


def test_setup_matplotlib():
    """setup_matplotlib() should select matplotlib's non-interactive Agg backend."""
    setup_matplotlib()
    # Compare case-insensitively. Depending on the matplotlib version, the
    # backend name may be reported as 'agg' or as 'Agg'.
    assert plt.get_backend().lower() == 'agg'


def test_plot_1d():
    """Test 1D plotting: a single axis holding one line, labelled from the data."""
    da_1d, _ = create_sample_data()

    fig = plot_1d(da_1d)
    assert fig is not None
    try:
        assert len(fig.axes) == 1

        ax = fig.axes[0]
        assert len(ax.lines) == 1
        # The 'x' coordinate has no 'units' attribute, so the unit part of the
        # label (presumably formatted as "name (units)") is empty.
        assert ax.get_xlabel() == 'x ()'
        assert 'Sine Wave' in ax.get_title()
    finally:
        # Close the figure even if an assertion fails, so open figures do not
        # pile up across tests.
        plt.close(fig)


def test_plot_2d():
    """Test 2D plotting with both the "image" and "contour" kinds."""
    _, da_2d = create_sample_data()

    # Image plot: check the axis labels and title derived from the data.
    fig = plot_2d(da_2d, kind="image")
    assert fig is not None
    try:
        assert len(fig.axes) >= 1  # Plot axis + possibly colorbar axis

        # Assumes the plot axis comes first and any colorbar axis comes after it.
        ax = fig.axes[0]
        assert ax.get_xlabel() == 'lon ()'
        assert ax.get_ylabel() == 'lat ()'
        assert 'Temperature' in ax.get_title()
    finally:
        # Always release the figure, even when an assertion above fails.
        plt.close(fig)

    # Contour plot: only checks that plotting succeeds.
    fig = plot_2d(da_2d, kind="contour")
    try:
        assert fig is not None
    finally:
        plt.close(fig)


def test_plot_styling():
    """Test that styling keyword arguments are forwarded to the plots."""
    da_1d, da_2d = create_sample_data()

    # 1D styling: the line colour and width must match what was requested.
    # NOTE(review): grid=False is passed but its effect is not asserted.
    fig = plot_1d(da_1d, color='red', linewidth=2, grid=False)
    try:
        ax = fig.axes[0]
        assert ax.lines[0].get_color() == 'red'
        assert ax.lines[0].get_linewidth() == 2
    finally:
        # Close even on assertion failure so figures do not leak across tests.
        plt.close(fig)

    # 2D styling: only checks that cmap/vmin/vmax are accepted.
    fig = plot_2d(da_2d, cmap='plasma', vmin=-1, vmax=1)
    try:
        assert fig is not None
    finally:
        plt.close(fig)


def test_auto_dimension_detection():
    """plot_2d should produce a figure when no dimension arguments are given."""
    da_2d = create_sample_data()[1]

    # No dimensions are specified, so plot_2d has to work them out itself.
    figure = plot_2d(da_2d)
    assert figure is not None
    plt.close(figure)