"""Tests for I/O operations."""

import os
import tempfile

import numpy as np
import pytest
import xarray as xr

from tensorview.io import open_any, list_variables, get_dataarray, detect_engine


def create_sample_netcdf(seed=0):
    """Create a sample NetCDF file for testing.

    Writes a Dataset containing a single ``temperature`` variable with
    dimensions ``(time, lat, lon)`` to a new temporary ``.nc`` file.

    Args:
        seed: Seed for the random temperature values so the file contents
            are reproducible across runs.

    Returns:
        Path of the written file. The caller is responsible for deleting it.
    """
    lons = np.arange(-180, 180, 2.5)
    lats = np.arange(-90, 90, 2.5)
    times = np.arange(0, 10)

    # Sample temperature data around 20 degC, one value per (time, lat, lon).
    rng = np.random.default_rng(seed)
    temp_data = rng.standard_normal((len(times), len(lats), len(lons))) + 20

    ds = xr.Dataset(
        {
            'temperature': (['time', 'lat', 'lon'], temp_data, {
                'units': 'degrees_C',
                'long_name': 'Temperature',
                'standard_name': 'air_temperature',
            }),
        },
        coords={
            'time': ('time', times, {'units': 'days since 2000-01-01'}),
            'lat': ('lat', lats, {'units': 'degrees_north', 'long_name': 'Latitude'}),
            'lon': ('lon', lons, {'units': 'degrees_east', 'long_name': 'Longitude'}),
        },
    )

    # Reserve a unique path, then release our descriptor before writing:
    # some platforms (notably Windows) refuse to reopen a file that is
    # still held open, which broke the previous NamedTemporaryFile approach.
    fd, path = tempfile.mkstemp(suffix='.nc')
    os.close(fd)
    try:
        ds.to_netcdf(path)
    except BaseException:
        # Don't leave a half-written temp file behind if the write fails.
        os.unlink(path)
        raise
    return path


def test_detect_engine():
    """Test engine detection from file extensions."""
    assert detect_engine('test.nc') == 'h5netcdf'
    assert detect_engine('test.grib') == 'cfgrib'
    assert detect_engine('test.zarr') == 'zarr'
    assert detect_engine('test.h5') == 'h5netcdf'


def test_open_netcdf():
    """Test opening NetCDF files, listing variables and reading a DataArray."""
    nc_file = create_sample_netcdf()
    try:
        handle = open_any(nc_file)
        assert handle is not None
        try:
            assert handle.engine in ['h5netcdf', 'netcdf4']

            # Variable listing
            variables = list_variables(handle)
            assert len(variables) == 1
            assert variables[0].name == 'temperature'
            assert variables[0].units == 'degrees_C'

            # Getting the data array
            da = get_dataarray(handle, 'temperature')
            assert da.name == 'temperature'
            assert len(da.dims) == 3
            assert 'time' in da.dims
            assert 'lat' in da.dims
            assert 'lon' in da.dims
        finally:
            # Close even when an assertion fails, so the file can be unlinked.
            handle.close()
    finally:
        os.unlink(nc_file)


def test_invalid_file(tmp_path):
    """Test error handling for invalid files."""
    # A path inside a fresh temp dir is guaranteed not to exist, unlike a
    # bare relative name that depends on the current working directory.
    with pytest.raises(RuntimeError):
        open_any(str(tmp_path / 'nonexistent_file.nc'))


def test_invalid_variable():
    """Test error handling for invalid variables."""
    nc_file = create_sample_netcdf()
    try:
        handle = open_any(nc_file)
        try:
            with pytest.raises(ValueError):
                get_dataarray(handle, 'nonexistent_variable')
        finally:
            # Close even if no ValueError was raised (pytest.raises fails).
            handle.close()
    finally:
        os.unlink(nc_file)