|
|
"""I/O operations for opening and managing datasets.""" |
|
|
|
|
|
import os |
|
|
from typing import Optional, Dict, Any, List |
|
|
from dataclasses import dataclass |
|
|
import xarray as xr |
|
|
import fsspec |
|
|
import warnings |
|
|
|
|
|
# NOTE(review): no message/module filter is given, so this suppresses every
# UserWarning for the whole process as soon as this module is imported, not
# only warnings raised while opening datasets. Kept as-is because importers
# may rely on the side effect; consider scoping it with
# ``warnings.catch_warnings()`` around the noisy calls instead.
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
|
|
|
|
@dataclass
class VariableSpec:
    """Metadata record describing one data variable of a dataset.

    ``list_variables`` builds one instance per data variable; the notes on
    the fields below describe how that function fills them in.
    """

    # Variable name as it appears in the dataset.
    name: str
    # Length of each dimension, in the same order as ``dims``.
    shape: tuple
    # Dimension names of the variable.
    dims: tuple
    # String form of the variable's dtype (``str(var.dtype)``).
    dtype: str
    # Value of the ``units`` attribute; empty string when the attribute is absent.
    units: str
    # Value of the ``long_name`` attribute; falls back to ``name`` when absent.
    long_name: str
    # Copy of the variable's full attribute mapping.
    attrs: Dict[str, Any]
|
|
|
|
|
|
|
|
class DatasetHandle:
    """Handle bundling an opened dataset with the URI and engine used for it.

    The handle may be used as a context manager; leaving the ``with`` block
    calls :meth:`close`, so the dataset is released even if the body raises.
    """

    def __init__(self, dataset: xr.Dataset, uri: str, engine: str):
        """Store the opened dataset together with how it was opened.

        Args:
            dataset: The opened dataset.
            uri: Path or URL the dataset was opened from.
            engine: Name of the engine that was used to open it.
        """
        self.dataset = dataset
        self.uri = uri
        self.engine = engine

    def __repr__(self) -> str:
        """Return a short description showing the URI and engine."""
        return f"{type(self).__name__}(uri={self.uri!r}, engine={self.engine!r})"

    def __enter__(self) -> "DatasetHandle":
        """Return the handle itself so ``with open_any(...) as h`` works."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the dataset; exceptions from the ``with`` body propagate."""
        self.close()

    def close(self):
        """Close the dataset."""
        # Tolerate dataset-like objects that expose no ``close`` method.
        if hasattr(self.dataset, 'close'):
            self.dataset.close()
|
|
|
|
|
|
|
|
# Suffixes (lower-case) recognised on the final path component.
_GRIB_SUFFIXES = ('.grib', '.grb', '.grib2', '.grb2')
_NETCDF_SUFFIXES = ('.nc', '.nc4', '.netcdf', '.hdf', '.hdf5', '.h5')


def _default_netcdf_engine() -> str:
    """Return 'h5netcdf' when that package imports cleanly, else 'netcdf4'."""
    try:
        import h5netcdf  # noqa: F401 - imported only to probe availability
    except ImportError:
        return 'netcdf4'
    return 'h5netcdf'


def detect_engine(uri: str) -> str:
    """Auto-detect the appropriate engine for a given URI/file.

    Detection is case-insensitive and works on the path with any URL query
    string, fragment and trailing slashes removed. Rules, in order:

    1. path ends with ``.zarr``                       -> ``'zarr'``
    2. path ends with a GRIB suffix                   -> ``'cfgrib'``
    3. path ends with a netCDF/HDF suffix             -> netCDF engine
    4. ``'zarr'`` appears anywhere in the URI         -> ``'zarr'``
    5. otherwise                                      -> netCDF engine

    The netCDF engine is ``'h5netcdf'`` when it can be imported, otherwise
    ``'netcdf4'``.

    Args:
        uri: Path or URL of the dataset.

    Returns:
        One of ``'zarr'``, ``'cfgrib'``, ``'h5netcdf'`` or ``'netcdf4'``.
    """
    lowered = uri.lower()
    # Strip "?query" / "#fragment" and trailing separators so that
    # "store.zarr/" and "file.grib?token=..." still expose their suffix.
    path = lowered.split('?', 1)[0].split('#', 1)[0].rstrip('/\\')

    if path.endswith('.zarr'):
        return 'zarr'
    if path.endswith(_GRIB_SUFFIXES):
        return 'cfgrib'
    if path.endswith(_NETCDF_SUFFIXES):
        return _default_netcdf_engine()
    # Loose legacy heuristic, consulted only after the explicit suffixes so a
    # directory merely named "zarr..." cannot override a real file extension.
    if 'zarr' in lowered:
        return 'zarr'
    return _default_netcdf_engine()
|
|
|
|
|
|
|
|
def open_any(uri: str, engine: Optional[str] = None, chunks: str = "auto") -> DatasetHandle:
    """
    Open a dataset from various sources (local, HTTP, S3, etc.).

    S3 URIs are accessed anonymously (``anon=True``). For non-zarr S3 objects
    the remote file object is deliberately left open after this function
    returns, because the dataset is loaded lazily and still reads from it.

    Args:
        uri: Path or URL to dataset
        engine: Engine to use ('h5netcdf', 'netcdf4', 'cfgrib', 'zarr').
            When ``None`` it is chosen by ``detect_engine(uri)``.
        chunks: Chunking strategy for dask

    Returns:
        DatasetHandle: Handle to the opened dataset

    Raises:
        RuntimeError: If opening fails for any reason; the original exception
            is attached as ``__cause__``.
    """
    if engine is None:
        engine = detect_engine(uri)

    is_s3 = uri.startswith('s3://')
    remote_file = None  # S3 file object that must outlive this call

    try:
        if engine == 'zarr':
            if is_s3:
                import s3fs
                fs = s3fs.S3FileSystem(anon=True)
                store = s3fs.S3Map(root=uri, s3=fs, check=False)
                ds = xr.open_zarr(store, chunks=chunks)
            else:
                ds = xr.open_zarr(uri, chunks=chunks)
        elif is_s3 and engine != 'cfgrib':
            import s3fs
            fs = s3fs.S3FileSystem(anon=True)
            # Do NOT wrap this in ``with``: variable data is read lazily, so
            # closing the file here would break every later read.
            # NOTE(review): the file object is expected to be released when
            # the dataset referencing it is garbage collected - confirm this
            # is acceptable for long-running processes.
            remote_file = fs.open(uri, 'rb')
            ds = xr.open_dataset(remote_file, engine=engine, chunks=chunks)
        else:
            # Local paths, http(s) URLs and every cfgrib source are handed to
            # xarray unchanged.
            ds = xr.open_dataset(uri, engine=engine, chunks=chunks)

        return DatasetHandle(ds, uri, engine)

    except Exception as e:
        # Nothing will use the remote file if opening failed, so release it.
        if remote_file is not None:
            remote_file.close()
        raise RuntimeError(f"Failed to open {uri} with engine {engine}: {str(e)}") from e
|
|
|
|
|
|
|
|
def list_variables(handle: DatasetHandle) -> List[VariableSpec]:
    """
    List all data variables in the dataset with their specifications.

    Variables whose name ends with ``_bounds`` are skipped.

    Args:
        handle: Dataset handle

    Returns:
        List of VariableSpec objects, in the dataset's iteration order.
        ``units`` defaults to ``''`` and ``long_name`` to the variable name
        when the corresponding attribute is missing.
    """
    variables = []

    # ``data_vars`` holds only non-coordinate variables, so no separate check
    # against ``dataset.coords`` is needed.
    for var_name, var in handle.dataset.data_vars.items():
        # str(): variable names are not guaranteed to be strings, and a
        # non-string name must not crash the suffix test.
        if str(var_name).endswith('_bounds'):
            continue

        # Copy so the spec does not alias the dataset's own attribute dict.
        attrs = dict(var.attrs)

        variables.append(
            VariableSpec(
                name=var_name,
                shape=var.shape,
                dims=var.dims,
                dtype=str(var.dtype),
                units=attrs.get('units', ''),
                long_name=attrs.get('long_name', var_name),
                attrs=attrs,
            )
        )

    return variables
|
|
|
|
|
|
|
|
def get_dataarray(handle: DatasetHandle, var: str) -> xr.DataArray:
    """
    Get a specific data array from the dataset.

    Only data variables are accepted; a name that is not present in
    ``handle.dataset.data_vars`` is rejected.

    Args:
        handle: Dataset handle
        var: Variable name

    Returns:
        xarray DataArray

    Raises:
        ValueError: If ``var`` is not one of the dataset's data variables.
    """
    if var not in handle.dataset.data_vars:
        raise ValueError(f"Variable '{var}' not found in dataset")

    return handle.dataset[var]
|
|
|
|
|
|
|
|
def close(handle: DatasetHandle):
    """Close a dataset handle.

    Function-style equivalent of calling ``handle.close()``.

    Args:
        handle: Dataset handle to close.
    """
    handle.close()