|
|
"""Tests for I/O operations.""" |
|
|
|
|
|
import pytest |
|
|
import numpy as np |
|
|
import xarray as xr |
|
|
import tempfile |
|
|
import os |
|
|
from tensorview.io import open_any, list_variables, get_dataarray, detect_engine |
|
|
|
|
|
|
|
|
def create_sample_netcdf():
    """Create a small sample NetCDF file on disk for testing.

    The file contains a single ``temperature`` data variable with dimensions
    ``(time, lat, lon)``: 10 time steps on a global 2.5-degree lat/lon grid.
    The variable and its coordinates carry ``units`` / ``long_name``
    attributes.

    Returns:
        str: Path to the written ``.nc`` file. The caller owns the file and
        must delete it (e.g. with ``os.unlink``) when done.
    """
    lons = np.arange(-180, 180, 2.5)
    lats = np.arange(-90, 90, 2.5)
    times = np.arange(0, 10)

    # Seeded generator so the fixture data is reproducible across test runs.
    rng = np.random.default_rng(0)
    temp_data = rng.standard_normal((len(times), len(lats), len(lons))) + 20

    ds = xr.Dataset(
        {
            'temperature': (['time', 'lat', 'lon'], temp_data, {
                'units': 'degrees_C',
                'long_name': 'Temperature',
                'standard_name': 'air_temperature',
            }),
        },
        coords={
            'time': ('time', times, {'units': 'days since 2000-01-01'}),
            'lat': ('lat', lats, {'units': 'degrees_north', 'long_name': 'Latitude'}),
            'lon': ('lon', lons, {'units': 'degrees_east', 'long_name': 'Longitude'}),
        },
    )

    # Reserve a unique path, then release our OS handle *before* writing.
    # Writing through a second handle while a NamedTemporaryFile still holds
    # the file open is not portable (it fails on Windows).
    fd, path = tempfile.mkstemp(suffix='.nc')
    os.close(fd)
    try:
        ds.to_netcdf(path)
    except Exception:
        # Don't leak the temporary file if serialisation fails.
        os.unlink(path)
        raise
    return path
|
|
|
|
|
|
|
|
def test_detect_engine():
    """Check that detect_engine picks the expected backend for each file extension."""
    expected_engines = [
        ('test.nc', 'h5netcdf'),
        ('test.grib', 'cfgrib'),
        ('test.zarr', 'zarr'),
        ('test.h5', 'h5netcdf'),
    ]
    for filename, engine in expected_engines:
        assert detect_engine(filename) == engine
|
|
|
|
|
|
|
|
def test_open_netcdf():
    """Open a NetCDF file, list its variables, and load one as a DataArray."""
    nc_file = create_sample_netcdf()
    try:
        handle = open_any(nc_file)
        assert handle is not None
        try:
            assert handle.engine in ['h5netcdf', 'netcdf4']

            # The sample file holds exactly one data variable.
            variables = list_variables(handle)
            assert len(variables) == 1
            assert variables[0].name == 'temperature'
            assert variables[0].units == 'degrees_C'

            da = get_dataarray(handle, 'temperature')
            assert da.name == 'temperature'
            assert len(da.dims) == 3
            assert set(da.dims) == {'time', 'lat', 'lon'}
        finally:
            # Close even when an assertion above fails. A still-open handle
            # leaks, and on Windows it also makes the unlink below fail,
            # masking the real assertion error.
            handle.close()
    finally:
        os.unlink(nc_file)
|
|
|
|
|
|
|
|
def test_invalid_file():
    """open_any raises RuntimeError when the path does not exist."""
    # Build the missing path inside a fresh, empty temporary directory so the
    # outcome cannot depend on stray files in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        missing_path = os.path.join(tmp_dir, 'nonexistent_file.nc')
        with pytest.raises(RuntimeError):
            open_any(missing_path)
|
|
|
|
|
|
|
|
def test_invalid_variable():
    """get_dataarray raises ValueError for a variable name not in the file."""
    nc_file = create_sample_netcdf()
    try:
        handle = open_any(nc_file)
        try:
            with pytest.raises(ValueError):
                get_dataarray(handle, 'nonexistent_variable')
        finally:
            # Always release the handle, even if pytest.raises fails because
            # no error was raised; otherwise the unlink below may fail on
            # Windows and hide the real failure.
            handle.close()
    finally:
        os.unlink(nc_file)