File size: 2,877 Bytes
433dab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Tests for I/O operations."""

import pytest
import numpy as np
import xarray as xr
import tempfile
import os
from tensorview.io import open_any, list_variables, get_dataarray, detect_engine


def create_sample_netcdf(seed=0):
    """Create a small sample NetCDF file on disk and return its path.

    The dataset holds one variable, ``temperature``, with dims
    ``(time, lat, lon)`` = (10, 72, 144). The values are standard-normal noise
    shifted by +20 and drawn from a seeded generator, so repeated calls produce
    identical files.

    Args:
        seed: Seed for the random generator used to fill ``temperature``.

    Returns:
        Path of the written ``.nc`` file. The caller is responsible for
        deleting it (e.g. with ``os.unlink``).
    """
    lons = np.arange(-180, 180, 2.5)
    lats = np.arange(-90, 90, 2.5)
    times = np.arange(0, 10)

    # Seeded generator keeps test fixtures reproducible across runs.
    rng = np.random.default_rng(seed)
    temp_data = rng.standard_normal((len(times), len(lats), len(lons))) + 20

    ds = xr.Dataset({
        'temperature': (['time', 'lat', 'lon'], temp_data, {
            'units': 'degrees_C',
            'long_name': 'Temperature',
            'standard_name': 'air_temperature'
        })
    }, coords={
        'time': ('time', times, {'units': 'days since 2000-01-01'}),
        'lat': ('lat', lats, {'units': 'degrees_north', 'long_name': 'Latitude'}),
        'lon': ('lon', lons, {'units': 'degrees_east', 'long_name': 'Longitude'})
    })

    fd, path = tempfile.mkstemp(suffix='.nc')
    # Release our OS handle before the NetCDF writer opens the path: some
    # platforms (notably Windows) refuse a second open of a held file.
    os.close(fd)
    try:
        ds.to_netcdf(path)
    except BaseException:
        # Don't leave an empty/partial temp file behind if writing fails.
        os.unlink(path)
        raise

    return path


def test_detect_engine():
    """Check that detect_engine maps each file extension to its expected backend."""
    expected_engines = {
        'test.nc': 'h5netcdf',
        'test.grib': 'cfgrib',
        'test.zarr': 'zarr',
        'test.h5': 'h5netcdf',
    }
    for path, engine in expected_engines.items():
        assert detect_engine(path) == engine


def test_open_netcdf():
    """Open a sample NetCDF file, then check its engine, variables and data array."""
    nc_file = create_sample_netcdf()

    try:
        handle = open_any(nc_file)
        assert handle is not None

        # Close the handle even when an assertion fails: otherwise the handle
        # leaks, and unlinking a still-open file fails on Windows, hiding the
        # original assertion error.
        try:
            assert handle.engine in ['h5netcdf', 'netcdf4']

            # Variable listing
            variables = list_variables(handle)
            assert len(variables) == 1
            assert variables[0].name == 'temperature'
            assert variables[0].units == 'degrees_C'

            # Data-array access
            da = get_dataarray(handle, 'temperature')
            assert da.name == 'temperature'
            assert len(da.dims) == 3
            assert 'time' in da.dims
            assert 'lat' in da.dims
            assert 'lon' in da.dims
        finally:
            handle.close()

    finally:
        os.unlink(nc_file)


def test_invalid_file():
    """Opening a path that does not exist must raise RuntimeError."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # A fresh, empty directory guarantees the target is absent, so the
        # test no longer depends on what happens to exist in the cwd.
        missing_path = os.path.join(tmp_dir, 'nonexistent_file.nc')
        with pytest.raises(RuntimeError):
            open_any(missing_path)


def test_invalid_variable():
    """Requesting an unknown variable from an open file must raise ValueError."""
    nc_file = create_sample_netcdf()

    try:
        handle = open_any(nc_file)
        # Always close the handle so the temp file can be unlinked afterwards,
        # even if the expected ValueError is not raised.
        try:
            with pytest.raises(ValueError):
                get_dataarray(handle, 'nonexistent_variable')
        finally:
            handle.close()

    finally:
        os.unlink(nc_file)