π TensorView v1.0 - Complete NetCDF/HDF/GRIB viewer
Browse filesFeatures:
- Multi-dimensional data exploration with smart slicing
- Geographic map plotting with Cartopy projections
- Automatic percentile-based color scaling
- Support for NetCDF, HDF5, GRIB, Zarr formats
- Interactive Gradio interface with dynamic controls
- File upload + direct path input modes
- Comprehensive scientific data visualization
π€ Generated with Claude Code
Co-Authored-By: Claude <noreply@anthropic.com>
- .gitignore +64 -0
- README.md +90 -6
- SPECS.md +181 -0
- app.py +265 -375
- assets/colormaps/sample.cpt +12 -0
- requirements.txt +20 -10
- tensorview/__init__.py +3 -0
- tensorview/anim.py +351 -0
- tensorview/colors.py +273 -0
- tensorview/grid.py +227 -0
- tensorview/io.py +162 -0
- tensorview/plot.py +354 -0
- tensorview/state.py +286 -0
- tensorview/utils.py +218 -0
- tests/test_io.py +102 -0
- tests/test_plot.py +100 -0
.gitignore
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.env
|
| 25 |
+
.venv
|
| 26 |
+
env/
|
| 27 |
+
venv/
|
| 28 |
+
ENV/
|
| 29 |
+
env.bak/
|
| 30 |
+
venv.bak/
|
| 31 |
+
|
| 32 |
+
# Jupyter Notebook
|
| 33 |
+
.ipynb_checkpoints
|
| 34 |
+
|
| 35 |
+
# Temporary files
|
| 36 |
+
*.tmp
|
| 37 |
+
*.temp
|
| 38 |
+
*~
|
| 39 |
+
|
| 40 |
+
# OS generated files
|
| 41 |
+
.DS_Store
|
| 42 |
+
.DS_Store?
|
| 43 |
+
._*
|
| 44 |
+
.Spotlight-V100
|
| 45 |
+
.Trashes
|
| 46 |
+
ehthumbs.db
|
| 47 |
+
Thumbs.db
|
| 48 |
+
|
| 49 |
+
# Test outputs
|
| 50 |
+
test_*.png
|
| 51 |
+
test_*.nc
|
| 52 |
+
|
| 53 |
+
# Debug files
|
| 54 |
+
debug_*.py
|
| 55 |
+
app_simple.py
|
| 56 |
+
app_fixed.py
|
| 57 |
+
|
| 58 |
+
# Large data files
|
| 59 |
+
*.nc
|
| 60 |
+
*.hdf
|
| 61 |
+
*.h5
|
| 62 |
+
*.grib
|
| 63 |
+
*.grb
|
| 64 |
+
tests/data/*.nc
|
README.md
CHANGED
|
@@ -1,12 +1,96 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: TensorView - NetCDF/HDF/GRIB Viewer
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# π TensorView - Interactive Geospatial Data Viewer
|
| 14 |
+
|
| 15 |
+
A powerful browser-based viewer for **NetCDF**, **HDF**, **GRIB**, and **Zarr** datasets with advanced visualization capabilities.
|
| 16 |
+
|
| 17 |
+
## π Features
|
| 18 |
+
|
| 19 |
+
- **π Multi-dimensional data exploration** - Handle complex scientific datasets with automatic slicing
|
| 20 |
+
- **πΊοΈ Geographic mapping** - Built-in map projections with coastlines and gridlines
|
| 21 |
+
- **π¨ Smart color scaling** - Automatic percentile-based color limits for optimal visualization
|
| 22 |
+
- **π Multiple data formats** - NetCDF, HDF5, GRIB, Zarr support
|
| 23 |
+
- **ποΈ Interactive controls** - Dynamic sliders for dimension exploration
|
| 24 |
+
- **π€ Dual input modes** - File upload or direct file path input
|
| 25 |
+
- **π Remote data support** - Load from URLs, OPeNDAP, THREDDS servers
|
| 26 |
+
|
| 27 |
+
## π― Quick Start
|
| 28 |
+
|
| 29 |
+
1. **Upload a file** or enter a file path
|
| 30 |
+
2. **Select a variable** from the dropdown
|
| 31 |
+
3. **Choose plot type**: 2D Image or Map (for geographic data)
|
| 32 |
+
4. **Adjust dimension sliders** to explore different time steps, pressure levels, etc.
|
| 33 |
+
5. **Create plot** and explore your data!
|
| 34 |
+
|
| 35 |
+
## π Supported Data Sources
|
| 36 |
+
|
| 37 |
+
- **NetCDF files** (.nc, .netcdf) - Climate and weather data
|
| 38 |
+
- **HDF5 files** (.h5, .hdf) - Scientific datasets
|
| 39 |
+
- **GRIB files** (.grib, .grb) - Meteorological data
|
| 40 |
+
- **Zarr stores** - Cloud-optimized arrays
|
| 41 |
+
- **Remote URLs** - HTTP/HTTPS links to data files
|
| 42 |
+
- **OPeNDAP/THREDDS** - Direct server access
|
| 43 |
+
|
| 44 |
+
## π Example Use Cases
|
| 45 |
+
|
| 46 |
+
- **Climate Data**: ERA5 reanalysis, CMIP model outputs
|
| 47 |
+
- **Weather Data**: GFS/ECMWF forecasts, radar data
|
| 48 |
+
- **Air Quality**: CAMS atmospheric composition data
|
| 49 |
+
- **Oceanography**: Sea surface temperature, currents
|
| 50 |
+
- **Satellite Data**: Remote sensing products
|
| 51 |
+
|
| 52 |
+
## π§ Technical Details
|
| 53 |
+
|
| 54 |
+
Built with:
|
| 55 |
+
- **xarray + Dask** - Efficient handling of large datasets
|
| 56 |
+
- **matplotlib + Cartopy** - High-quality plotting and maps
|
| 57 |
+
- **Gradio** - Interactive web interface
|
| 58 |
+
- **Multi-engine support** - h5netcdf, netcdf4, cfgrib, zarr
|
| 59 |
+
|
| 60 |
+
### Smart Features
|
| 61 |
+
- **Automatic color scaling** using 2nd-98th percentiles
|
| 62 |
+
- **Dimension detection** with dynamic slider generation
|
| 63 |
+
- **Geographic coordinate recognition** for map plotting
|
| 64 |
+
- **Memory-efficient** lazy loading with Dask
|
| 65 |
+
|
| 66 |
+
## π‘ Tips
|
| 67 |
+
|
| 68 |
+
- For **5D data** (like CAMS forecasts): Use sliders to select time, pressure level, etc.
|
| 69 |
+
- For **geographic data**: Choose "Map" plot type for proper projections
|
| 70 |
+
- **Large files**: The app handles big datasets efficiently with lazy loading
|
| 71 |
+
- **Color issues**: The app automatically optimizes color scaling to avoid uniform plots
|
| 72 |
+
|
| 73 |
+
## ποΈ Architecture
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
tensorview/
|
| 77 |
+
βββ io.py # Data loading (NetCDF, HDF, GRIB, Zarr)
|
| 78 |
+
βββ plot.py # Visualization (1D, 2D, maps)
|
| 79 |
+
βββ grid.py # Data operations and alignment
|
| 80 |
+
βββ colors.py # Colormap handling
|
| 81 |
+
βββ utils.py # Coordinate inference
|
| 82 |
+
βββ ...
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## π Example Datasets
|
| 86 |
+
|
| 87 |
+
The app works great with:
|
| 88 |
+
- NASA Goddard Earth Sciences Data
|
| 89 |
+
- ECMWF ERA5 reanalysis
|
| 90 |
+
- NOAA climate datasets
|
| 91 |
+
- Copernicus atmosphere monitoring (CAMS)
|
| 92 |
+
- CMIP climate model outputs
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
|
| 96 |
+
**π Links**: [GitHub Repository](https://github.com/user/tensorview) | [Documentation](https://docs.tensorview.io)
|
SPECS.md
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPECS.md β Panoply-Lite (Python/Gradio)
|
| 2 |
+
|
| 3 |
+
Goal: A browser-based viewer for **netCDF/HDF/GRIB/Zarr** datasets with an **xarray+Dask** backend, **Cartopy** maps, multi-dimensional slicing, animations, custom color tables (CPT/ACT/RGB), map projections, **OPeNDAP/THREDDS/S3/Zarr** support, and exports (PNG/SVG/PDF/MP4). Think βPanoply in the browser.β
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 0) TL;DR Build Contract
|
| 8 |
+
|
| 9 |
+
- **Deliverable**: `app.py` + `panlite/` package + tests + pinned `pyproject.toml` (or `requirements.txt`) + minimal sample assets.
|
| 10 |
+
- **Runtime**: Python 3.11; CPU-only acceptable; FFmpeg required for MP4.
|
| 11 |
+
- **UI**: Gradio Blocks (single page).
|
| 12 |
+
- **Perf**: Lazy I/O with Dask; responsive for slices β€ ~5e6 elements.
|
| 13 |
+
- **Outcomes (Definition of Done)**: Load ERA5/CMIP-like datasets; produce a global map, a timeβlatitude section, a 1D line plot, an animation (MP4), and an AβB difference contour; export PNG/SVG/PDF; save/load view state JSON.
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 1) Scope
|
| 18 |
+
|
| 19 |
+
### 1.1 MVP Features
|
| 20 |
+
- **Open sources**: local files; `https://` (incl. **OPeNDAP/THREDDS**); `s3://` (anonymous or creds); `zarr://` (local/remote).
|
| 21 |
+
- **Formats/engines**:
|
| 22 |
+
- netCDF/HDF5 β `xarray.open_dataset(..., engine="h5netcdf" | "netcdf4")`
|
| 23 |
+
- GRIB β `cfgrib` (requires ecCodes)
|
| 24 |
+
- Zarr β `xarray.open_zarr(...)`
|
| 25 |
+
- **Discovery**: list variables, dims, shapes, dtypes, attrs; CF axis inference (lat, lon, time, level).
|
| 26 |
+
- **Slicing**: choose X/Y axes; set index/value selectors for remaining dims (nearest when numeric); time selectors.
|
| 27 |
+
- **Plots**:
|
| 28 |
+
- 1D: line (any single dim)
|
| 29 |
+
- 2D: image/contour (any two dims)
|
| 30 |
+
- Map: lonβlat georeferenced 2D with Cartopy (projections, coastlines, gridlines)
|
| 31 |
+
- Sections: timeβlat, timeβlev, lonβlev (pick dims)
|
| 32 |
+
- **Combine**: sum/avg/diff of two variables on a common grid (with reindex/broadcast; error out if CRS differs).
|
| 33 |
+
- **Colors**: matplotlib colormaps + load **CPT/ACT/RGB** tables.
|
| 34 |
+
- **Overlays**: coastlines, borders; optional land/ocean masks.
|
| 35 |
+
- **Export**: PNG/SVG/PDF; **MP4** (or frames) for animations.
|
| 36 |
+
- **State**: save/load a JSON βview state.β
|
| 37 |
+
|
| 38 |
+
### 1.2 v1.1 Stretch (optional)
|
| 39 |
+
- **KMZ** export (ground overlay tiles).
|
| 40 |
+
- **GeoTIFF** export for georeferenced 2D arrays.
|
| 41 |
+
- **Trajectory** plots (CF trajectories).
|
| 42 |
+
- **Zonal mean** helper (lat/lon aggregation).
|
| 43 |
+
- **xESMF** regridding for combine.
|
| 44 |
+
|
| 45 |
+
### 1.3 Non-Goals (v1)
|
| 46 |
+
3D volume rendering; advanced reprojection pipelines; interactive WebGL.
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## 2) Architecture & Layout
|
| 51 |
+
|
| 52 |
+
panoply-lite/
|
| 53 |
+
app.py
|
| 54 |
+
panlite/
|
| 55 |
+
__init__.py
|
| 56 |
+
io.py # open/close; engine select; cache; variable listing
|
| 57 |
+
grid.py # alignment, simple reindex/broadcast; combine ops; sections
|
| 58 |
+
plot.py # 1D/2D/map plotting; exports
|
| 59 |
+
anim.py # animate over a dim -> MP4/frames
|
| 60 |
+
colors.py # CPT/ACT/RGB loaders -> matplotlib Colormap
|
| 61 |
+
state.py # view-state (serialize/deserialize; schema validate)
|
| 62 |
+
utils.py # CF axis guess; CRS helpers; small helpers
|
| 63 |
+
assets/
|
| 64 |
+
colormaps/ # sample .cpt/.act/.rgb
|
| 65 |
+
tests/
|
| 66 |
+
test_io.py
|
| 67 |
+
test_plot.py
|
| 68 |
+
test_anim.py
|
| 69 |
+
data/ # tiny sample datasets (β€5 MB)
|
| 70 |
+
pyproject.toml (or requirements.txt)
|
| 71 |
+
README.md
|
| 72 |
+
SPECS.md
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
## 3) Dependencies (pin reasonably)
|
| 77 |
+
|
| 78 |
+
- Core: xarray, dask[complete], numpy, pandas, fsspec, s3fs, zarr, h5netcdf, cfgrib, eccodes
|
| 79 |
+
- Geo: cartopy, pyproj, rioxarray
|
| 80 |
+
- Viz: matplotlib
|
| 81 |
+
- Web UI: gradio>=4
|
| 82 |
+
- Misc: pydantic, orjson, pytest, ruff
|
| 83 |
+
- System: ffmpeg in PATH for MP4
|
| 84 |
+
|
| 85 |
+
requirements.txt example:
|
| 86 |
+
|
| 87 |
+
xarray
|
| 88 |
+
dask[complete]
|
| 89 |
+
fsspec
|
| 90 |
+
s3fs
|
| 91 |
+
zarr
|
| 92 |
+
h5netcdf
|
| 93 |
+
cfgrib
|
| 94 |
+
eccodes
|
| 95 |
+
rioxarray
|
| 96 |
+
cartopy
|
| 97 |
+
pyproj
|
| 98 |
+
matplotlib
|
| 99 |
+
gradio>=4
|
| 100 |
+
numpy
|
| 101 |
+
pandas
|
| 102 |
+
pydantic
|
| 103 |
+
orjson
|
| 104 |
+
pytest
|
| 105 |
+
|
| 106 |
+
---
|
| 107 |
+
|
| 108 |
+
## 4) API Surface
|
| 109 |
+
|
| 110 |
+
### io.py
|
| 111 |
+
- open_any(uri, engine=None, chunks="auto") -> DatasetHandle
|
| 112 |
+
- list_variables(handle) -> list[VariableSpec]
|
| 113 |
+
- get_dataarray(handle, var) -> xr.DataArray
|
| 114 |
+
- close(handle)
|
| 115 |
+
|
| 116 |
+
### grid.py
|
| 117 |
+
- align_for_combine(a, b, method="reindex")
|
| 118 |
+
- combine(a, b, op="sum"|"avg"|"diff")
|
| 119 |
+
- section(da, along: str, fixed: dict)
|
| 120 |
+
|
| 121 |
+
### plot.py
|
| 122 |
+
- plot_1d(da, **style) -> Figure
|
| 123 |
+
- plot_2d(da, kind="image"|"contour", **style) -> Figure
|
| 124 |
+
- plot_map(da, proj, **style) -> Figure
|
| 125 |
+
- export_fig(fig, fmt="png"|"svg"|"pdf", dpi=150, out_path=None)
|
| 126 |
+
|
| 127 |
+
### anim.py
|
| 128 |
+
- animate_over_dim(da, dim: str, fps=10, out="anim.mp4") -> str
|
| 129 |
+
|
| 130 |
+
### colors.py
|
| 131 |
+
- load_cpt(path) -> Colormap
|
| 132 |
+
- load_act(path) -> Colormap
|
| 133 |
+
- load_rgb(path) -> Colormap
|
| 134 |
+
|
| 135 |
+
### state.py
|
| 136 |
+
- dump_state(dict) -> str (JSON)
|
| 137 |
+
- load_state(str) -> dict
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## 5) UI Spec (Gradio Blocks)
|
| 142 |
+
|
| 143 |
+
Sidebar:
|
| 144 |
+
- File open (local, URL, S3)
|
| 145 |
+
- Dataset vars: pick A, optional B; operation (sum/avg/diff)
|
| 146 |
+
- Axes: X, Y; slice others (sliders, dropdowns)
|
| 147 |
+
- Plot type: 1D, 2D image, 2D contour, Map
|
| 148 |
+
- Projection: PlateCarree, Robinson, etc
|
| 149 |
+
- Colormap: dropdown + upload CPT/ACT/RGB
|
| 150 |
+
- Colorbar options; vmin/vmax; levels
|
| 151 |
+
- Animation: dim=[time|lev], FPS, MP4 export
|
| 152 |
+
- Export: PNG/SVG/PDF
|
| 153 |
+
- Save/load view state JSON
|
| 154 |
+
|
| 155 |
+
Main Panel:
|
| 156 |
+
- Figure canvas (matplotlib)
|
| 157 |
+
- Metadata panel: units, attrs, dims
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## 6) Acceptance Criteria
|
| 162 |
+
|
| 163 |
+
- Open ERA5 (Zarr/OPeNDAP) dataset.
|
| 164 |
+
- Plot: (a) global 2D map with Robinson projection; (b) timeβlat section; (c) 1D time series.
|
| 165 |
+
- Perform AβB difference contour plot.
|
| 166 |
+
- Load custom CPT colormap and export as SVG/PDF.
|
| 167 |
+
- Animate over time dimension (24 frames), export MP4.
|
| 168 |
+
- Save view state, reload reproduces identical figure.
|
| 169 |
+
|
| 170 |
+
---
|
| 171 |
+
|
| 172 |
+
## 7) Testing Checklist
|
| 173 |
+
|
| 174 |
+
- Local netCDF, remote OPeNDAP, public S3, Zarr open successfully.
|
| 175 |
+
- Variable discovery works; CF axis inference correct.
|
| 176 |
+
- 1D/2D/Map plotting functional.
|
| 177 |
+
- Combine AΒ±B correct (aligned grid).
|
| 178 |
+
- Custom CPT colormap applied correctly.
|
| 179 |
+
- Export PNG/SVG/PDF correct dimensions/DPI.
|
| 180 |
+
- Animation over time produces correct frame count, valid MP4.
|
| 181 |
+
- Large datasets responsive due to Dask.
|
app.py
CHANGED
|
@@ -1,444 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import xarray as xr
|
| 3 |
-
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
import
|
| 10 |
-
import
|
| 11 |
-
import
|
| 12 |
-
from
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
-
def
|
| 18 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
try:
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
}
|
| 30 |
|
| 31 |
-
#
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
'min': float(ds[var].min().values) if ds[var].size > 0 else None,
|
| 39 |
-
'max': float(ds[var].max().values) if ds[var].size > 0 else None,
|
| 40 |
-
'mean': float(ds[var].mean().values) if ds[var].size > 0 else None
|
| 41 |
-
}
|
| 42 |
-
info['data_vars_info'][var] = var_info
|
| 43 |
-
|
| 44 |
-
# Generate summary text
|
| 45 |
-
summary = f"""
|
| 46 |
-
## Dataset Overview
|
| 47 |
-
- **Dimensions**: {len(ds.dims)} ({', '.join([f"{k}: {v}" for k, v in ds.dims.items()])})
|
| 48 |
-
- **Variables**: {len(ds.data_vars)} data variables, {len(ds.coords)} coordinates
|
| 49 |
-
- **Global Attributes**: {len(ds.attrs)} attributes
|
| 50 |
|
| 51 |
-
### Variables:
|
| 52 |
"""
|
| 53 |
-
for
|
| 54 |
-
summary += f"- **{
|
| 55 |
-
if var_info['min'] is not None:
|
| 56 |
-
summary += f" [{var_info['min']:.2f} to {var_info['max']:.2f}]"
|
| 57 |
-
summary += "\n"
|
| 58 |
|
| 59 |
-
|
| 60 |
-
return summary, info
|
| 61 |
|
| 62 |
except Exception as e:
|
| 63 |
-
return f"Error
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
|
| 67 |
-
"""Create
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
try:
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
raise ValueError(f"Variable '{variable}' not found in dataset")
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
spatial_dims = []
|
| 80 |
-
for dim in data_var.dims:
|
| 81 |
-
if any(name in dim.lower() for name in ['lat', 'lon', 'x', 'y']):
|
| 82 |
-
spatial_dims.append(dim)
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# Select subset based on other dimensions
|
| 89 |
-
data_subset = data_var
|
| 90 |
-
for dim in data_var.dims:
|
| 91 |
-
if dim not in [dim1, dim2]:
|
| 92 |
-
if 'time' in dim.lower():
|
| 93 |
-
data_subset = data_subset.isel({dim: min(time_idx, data_var.sizes[dim]-1)})
|
| 94 |
-
elif any(name in dim.lower() for name in ['level', 'depth', 'height']):
|
| 95 |
-
data_subset = data_subset.isel({dim: min(level_idx, data_var.sizes[dim]-1)})
|
| 96 |
-
else:
|
| 97 |
-
data_subset = data_subset.isel({dim: 0})
|
| 98 |
else:
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
if len(dims) >= 2:
|
| 102 |
-
data_subset = data_var.isel({dim: 0 for dim in dims[2:]})
|
| 103 |
-
else:
|
| 104 |
-
data_subset = data_var
|
| 105 |
-
dim1, dim2 = dims[:2]
|
| 106 |
-
else:
|
| 107 |
-
raise ValueError("Data must have at least 2 dimensions for 2D plotting")
|
| 108 |
-
|
| 109 |
-
# Create the plot
|
| 110 |
-
fig = go.Figure(data=go.Heatmap(
|
| 111 |
-
z=data_subset.values,
|
| 112 |
-
x=data_subset.coords[dim2].values if dim2 in data_subset.coords else None,
|
| 113 |
-
y=data_subset.coords[dim1].values if dim1 in data_subset.coords else None,
|
| 114 |
-
colorscale=colormap,
|
| 115 |
-
colorbar=dict(title=data_var.attrs.get('units', 'Value'))
|
| 116 |
-
))
|
| 117 |
-
|
| 118 |
-
fig.update_layout(
|
| 119 |
-
title=f"{variable} - {data_var.attrs.get('long_name', variable)}",
|
| 120 |
-
xaxis_title=dim2,
|
| 121 |
-
yaxis_title=dim1,
|
| 122 |
-
height=600,
|
| 123 |
-
width=800
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
ds.close()
|
| 127 |
-
return fig
|
| 128 |
-
|
| 129 |
-
except Exception as e:
|
| 130 |
-
# Return empty figure with error message
|
| 131 |
-
fig = go.Figure()
|
| 132 |
-
fig.add_annotation(
|
| 133 |
-
text=f"Error creating plot: {str(e)}",
|
| 134 |
-
x=0.5, y=0.5,
|
| 135 |
-
xref="paper", yref="paper",
|
| 136 |
-
showarrow=False,
|
| 137 |
-
font=dict(size=16, color="red")
|
| 138 |
-
)
|
| 139 |
-
return fig
|
| 140 |
-
|
| 141 |
-
def create_time_series(file_path: str, variable: str, method: str = "mean") -> go.Figure:
|
| 142 |
-
"""Create time series plot by aggregating spatial dimensions."""
|
| 143 |
-
try:
|
| 144 |
-
ds = xr.open_dataset(file_path)
|
| 145 |
-
|
| 146 |
-
if variable not in ds.data_vars:
|
| 147 |
-
raise ValueError(f"Variable '{variable}' not found in dataset")
|
| 148 |
-
|
| 149 |
-
data_var = ds[variable]
|
| 150 |
-
|
| 151 |
-
# Find time dimension
|
| 152 |
-
time_dim = None
|
| 153 |
-
for dim in data_var.dims:
|
| 154 |
-
if 'time' in dim.lower():
|
| 155 |
-
time_dim = dim
|
| 156 |
-
break
|
| 157 |
-
|
| 158 |
-
if time_dim is None:
|
| 159 |
-
raise ValueError("No time dimension found in the data")
|
| 160 |
-
|
| 161 |
-
# Aggregate spatial dimensions
|
| 162 |
-
spatial_dims = [dim for dim in data_var.dims if dim != time_dim]
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
time_series = data_var.max(dim=spatial_dims)
|
| 168 |
-
elif method == "min":
|
| 169 |
-
time_series = data_var.min(dim=spatial_dims)
|
| 170 |
else:
|
| 171 |
-
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
mode='lines+markers',
|
| 177 |
-
name=f"{method.title()} {variable}"
|
| 178 |
-
))
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
except Exception as e:
|
| 191 |
-
fig = go.Figure()
|
| 192 |
-
fig.add_annotation(
|
| 193 |
-
text=f"Error creating time series: {str(e)}",
|
| 194 |
-
x=0.5, y=0.5,
|
| 195 |
-
xref="paper", yref="paper",
|
| 196 |
-
showarrow=False,
|
| 197 |
-
font=dict(size=16, color="red")
|
| 198 |
-
)
|
| 199 |
-
return fig
|
| 200 |
-
|
| 201 |
-
def create_vertical_profile(file_path: str, variable: str, time_idx: int = 0) -> go.Figure:
|
| 202 |
-
"""Create vertical profile plot."""
|
| 203 |
-
try:
|
| 204 |
-
ds = xr.open_dataset(file_path)
|
| 205 |
-
|
| 206 |
-
if variable not in ds.data_vars:
|
| 207 |
-
raise ValueError(f"Variable '{variable}' not found in dataset")
|
| 208 |
-
|
| 209 |
-
data_var = ds[variable]
|
| 210 |
-
|
| 211 |
-
# Find vertical dimension
|
| 212 |
-
vertical_dim = None
|
| 213 |
-
for dim in data_var.dims:
|
| 214 |
-
if any(name in dim.lower() for name in ['level', 'depth', 'height', 'pressure']):
|
| 215 |
-
vertical_dim = dim
|
| 216 |
-
break
|
| 217 |
-
|
| 218 |
-
if vertical_dim is None:
|
| 219 |
-
raise ValueError("No vertical dimension found in the data")
|
| 220 |
-
|
| 221 |
-
# Average over horizontal dimensions, select time
|
| 222 |
-
dims_to_avg = []
|
| 223 |
-
for dim in data_var.dims:
|
| 224 |
-
if dim != vertical_dim:
|
| 225 |
-
if 'time' in dim.lower():
|
| 226 |
-
data_var = data_var.isel({dim: min(time_idx, data_var.sizes[dim]-1)})
|
| 227 |
-
else:
|
| 228 |
-
dims_to_avg.append(dim)
|
| 229 |
-
|
| 230 |
-
if dims_to_avg:
|
| 231 |
-
profile = data_var.mean(dim=dims_to_avg)
|
| 232 |
else:
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
-
|
| 250 |
-
return fig
|
| 251 |
|
| 252 |
except Exception as e:
|
| 253 |
-
fig =
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
xref="paper", yref="paper",
|
| 258 |
-
showarrow=False,
|
| 259 |
-
font=dict(size=16, color="red")
|
| 260 |
-
)
|
| 261 |
-
return fig
|
| 262 |
|
| 263 |
-
def
|
| 264 |
-
"""
|
| 265 |
-
|
| 266 |
-
return "Please upload a NetCDF file.", {}, [], []
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix='.nc') as tmp_file:
|
| 271 |
-
tmp_file.write(file.read())
|
| 272 |
-
tmp_path = tmp_file.name
|
| 273 |
-
|
| 274 |
-
# Analyze the file
|
| 275 |
-
summary, info = analyze_netcdf(tmp_path)
|
| 276 |
-
|
| 277 |
-
# Get variable options
|
| 278 |
-
variable_options = list(info.get('data_vars_info', {}).keys())
|
| 279 |
-
|
| 280 |
-
# Get dimension options for slicing
|
| 281 |
-
dimensions = info.get('dimensions', {})
|
| 282 |
-
|
| 283 |
-
return summary, tmp_path, variable_options, list(dimensions.keys())
|
| 284 |
-
|
| 285 |
-
except Exception as e:
|
| 286 |
-
return f"Error processing file: {str(e)}", "", [], []
|
| 287 |
-
|
| 288 |
-
def update_plot(file_path: str, variable: str, plot_type: str, time_idx: int,
|
| 289 |
-
level_idx: int, colormap: str, aggregation_method: str):
|
| 290 |
-
"""Update plot based on user selections."""
|
| 291 |
-
if not file_path or not variable:
|
| 292 |
-
return go.Figure()
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
elif plot_type == "Time Series":
|
| 298 |
-
return create_time_series(file_path, variable, aggregation_method)
|
| 299 |
-
elif plot_type == "Vertical Profile":
|
| 300 |
-
return create_vertical_profile(file_path, variable, time_idx)
|
| 301 |
-
else:
|
| 302 |
-
return go.Figure()
|
| 303 |
-
except Exception as e:
|
| 304 |
-
fig = go.Figure()
|
| 305 |
-
fig.add_annotation(
|
| 306 |
-
text=f"Error: {str(e)}",
|
| 307 |
-
x=0.5, y=0.5,
|
| 308 |
-
xref="paper", yref="paper",
|
| 309 |
-
showarrow=False
|
| 310 |
-
)
|
| 311 |
-
return fig
|
| 312 |
-
|
| 313 |
-
# Create Gradio interface
|
| 314 |
-
with gr.Blocks(title="NetCDF Explorer π", theme=gr.themes.Soft()) as app:
|
| 315 |
-
gr.Markdown("""
|
| 316 |
-
# π NetCDF Explorer
|
| 317 |
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
- π Vertical profiles
|
| 324 |
-
- π¨ Customizable colormaps
|
| 325 |
-
- π Comprehensive metadata analysis
|
| 326 |
-
""")
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
file_types=[".nc", ".netcdf"],
|
| 333 |
-
type="binary"
|
| 334 |
-
)
|
| 335 |
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
-
# Control panel
|
| 341 |
with gr.Row():
|
| 342 |
with gr.Column(scale=1):
|
| 343 |
-
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
choices=[],
|
| 346 |
interactive=True
|
| 347 |
)
|
| 348 |
|
| 349 |
plot_type = gr.Radio(
|
| 350 |
label="Plot Type",
|
| 351 |
-
choices=["2D
|
| 352 |
-
value="2D
|
| 353 |
)
|
| 354 |
|
| 355 |
-
|
| 356 |
label="Colormap",
|
| 357 |
-
choices=["viridis", "plasma", "
|
| 358 |
-
"Blues", "Reds", "RdYlBu", "RdBu", "coolwarm"],
|
| 359 |
value="viridis"
|
| 360 |
)
|
| 361 |
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
value="mean",
|
| 366 |
-
visible=False
|
| 367 |
-
)
|
| 368 |
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
minimum=0,
|
| 373 |
-
maximum=100,
|
| 374 |
-
value=0,
|
| 375 |
-
step=1
|
| 376 |
-
)
|
| 377 |
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
update_btn = gr.Button("Update Plot", variant="primary")
|
| 387 |
-
|
| 388 |
-
# Plot display
|
| 389 |
-
plot_display = gr.Plot(label="Visualization")
|
| 390 |
|
| 391 |
-
# Hidden state
|
| 392 |
file_path_state = gr.State("")
|
| 393 |
|
| 394 |
# Event handlers
|
| 395 |
-
def
|
| 396 |
-
summary,
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
gr.update(value=tmp_path), # file_path_state
|
| 403 |
-
]
|
| 404 |
-
|
| 405 |
-
return updates
|
| 406 |
|
| 407 |
-
def
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
-
def
|
| 414 |
-
|
|
|
|
| 415 |
|
| 416 |
-
# Connect
|
| 417 |
file_upload.upload(
|
| 418 |
-
fn=
|
| 419 |
inputs=[file_upload],
|
| 420 |
-
outputs=[
|
| 421 |
)
|
| 422 |
|
| 423 |
-
|
| 424 |
-
fn=
|
| 425 |
-
inputs=[
|
| 426 |
-
outputs=[
|
| 427 |
)
|
| 428 |
|
| 429 |
-
|
| 430 |
-
fn=
|
| 431 |
-
inputs=[
|
| 432 |
-
|
| 433 |
-
outputs=[plot_display]
|
| 434 |
)
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
level_slider, colormap_dropdown, aggregation_method],
|
| 441 |
-
outputs=[plot_display]
|
| 442 |
)
|
| 443 |
|
| 444 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
"""TensorView - Final version for Hugging Face deployment."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
import gradio as gr
|
|
|
|
|
|
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Optional, Dict, Any, List, Tuple
|
| 8 |
+
|
| 9 |
+
# TensorView imports
|
| 10 |
+
from tensorview.io import open_any, list_variables, get_dataarray, close, DatasetHandle
|
| 11 |
+
from tensorview.plot import plot_1d, plot_2d, plot_map, export_fig
|
| 12 |
+
from tensorview.colors import get_matplotlib_colormaps
|
| 13 |
+
from tensorview.utils import identify_coordinates
|
| 14 |
|
| 15 |
+
# Global state
|
| 16 |
+
current_handle: Optional[DatasetHandle] = None
|
| 17 |
+
current_data: Optional[Dict[str, Any]] = None
|
| 18 |
|
| 19 |
+
def process_file_or_path(file_or_path) -> Tuple[str, List[str], str]:
|
| 20 |
+
"""Process uploaded file or file path."""
|
| 21 |
+
global current_handle, current_data
|
| 22 |
+
|
| 23 |
+
# Determine if it's a file object or path string
|
| 24 |
+
if hasattr(file_or_path, 'name'):
|
| 25 |
+
file_path = file_or_path.name
|
| 26 |
+
source = f"Uploaded: {os.path.basename(file_path)}"
|
| 27 |
+
else:
|
| 28 |
+
file_path = str(file_or_path)
|
| 29 |
+
source = f"Path: {os.path.basename(file_path)}"
|
| 30 |
+
|
| 31 |
try:
|
| 32 |
+
# Close previous handle
|
| 33 |
+
if current_handle:
|
| 34 |
+
close(current_handle)
|
| 35 |
+
|
| 36 |
+
# Open dataset
|
| 37 |
+
current_handle = open_any(file_path)
|
| 38 |
+
variables = list_variables(current_handle)
|
| 39 |
+
|
| 40 |
+
# Get variable info and dataset info
|
| 41 |
+
var_choices = [v.name for v in variables]
|
| 42 |
+
ds = current_handle.dataset
|
| 43 |
+
|
| 44 |
+
# Store dimension information
|
| 45 |
+
dims_info = {}
|
| 46 |
+
for dim_name, dim_size in ds.dims.items():
|
| 47 |
+
if dim_name in ds.coords:
|
| 48 |
+
coord = ds.coords[dim_name]
|
| 49 |
+
dims_info[dim_name] = {
|
| 50 |
+
'size': dim_size,
|
| 51 |
+
'values': coord.values,
|
| 52 |
+
'min': float(coord.min().values),
|
| 53 |
+
'max': float(coord.max().values)
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
current_data = {
|
| 57 |
+
'variables': {v.name: {'dims': v.dims, 'shape': v.shape, 'long_name': v.long_name, 'units': v.units}
|
| 58 |
+
for v in variables},
|
| 59 |
+
'dimensions': dims_info
|
| 60 |
}
|
| 61 |
|
| 62 |
+
# Create summary
|
| 63 |
+
summary = f"""β
**Dataset Loaded Successfully!**
|
| 64 |
+
|
| 65 |
+
**Source:** {source}
|
| 66 |
+
**Engine:** {current_handle.engine}
|
| 67 |
+
**Dimensions:** {len(ds.dims)} ({', '.join([f"{k}: {v}" for k, v in ds.dims.items()])})
|
| 68 |
+
**Variables:** {len(var_choices)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
### Available Variables:
|
| 71 |
"""
|
| 72 |
+
for v in variables:
|
| 73 |
+
summary += f"- **{v.name}**: {v.shape} [{v.units}] - {v.long_name}\n"
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
return summary, var_choices, file_path
|
|
|
|
| 76 |
|
| 77 |
except Exception as e:
|
| 78 |
+
return f"β **Error loading data:** {str(e)}", [], ""
|
| 79 |
|
| 80 |
+
def create_plot_with_auto_slicing(file_path: str, variable: str, plot_type: str,
|
| 81 |
+
colormap: str, *slider_values) -> Tuple[plt.Figure, str]:
|
| 82 |
+
"""Create plot with automatic dimension reduction."""
|
| 83 |
+
global current_handle, current_data
|
| 84 |
+
|
| 85 |
+
if not current_handle or not variable or not current_data:
|
| 86 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 87 |
+
ax.text(0.5, 0.5, 'π« No data loaded', ha='center', va='center',
|
| 88 |
+
transform=ax.transAxes, fontsize=16)
|
| 89 |
+
return fig, ""
|
| 90 |
+
|
| 91 |
try:
|
| 92 |
+
# Get data
|
| 93 |
+
da = get_dataarray(current_handle, variable)
|
| 94 |
+
var_info = current_data['variables'][variable]
|
|
|
|
| 95 |
|
| 96 |
+
# Apply dimension slicing
|
| 97 |
+
selection_dict = {}
|
| 98 |
+
slider_idx = 0
|
| 99 |
|
| 100 |
+
for dim in da.dims:
|
| 101 |
+
if any(spatial in dim.lower() for spatial in ['lat', 'lon', 'x', 'y']):
|
| 102 |
+
continue # Keep spatial dimensions
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
# Use slider value or default to 0
|
| 105 |
+
if slider_idx < len(slider_values) and slider_values[slider_idx] is not None:
|
| 106 |
+
selection_dict[dim] = int(slider_values[slider_idx])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
else:
|
| 108 |
+
selection_dict[dim] = 0
|
| 109 |
+
slider_idx += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
# Apply slicing
|
| 112 |
+
if selection_dict:
|
| 113 |
+
da_plot = da.isel(selection_dict)
|
|
|
|
|
|
|
|
|
|
| 114 |
else:
|
| 115 |
+
da_plot = da
|
| 116 |
|
| 117 |
+
# Calculate smart color limits using percentiles
|
| 118 |
+
data_values = da_plot.values
|
| 119 |
+
finite_data = data_values[np.isfinite(data_values)]
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
if len(finite_data) > 0:
|
| 122 |
+
data_min = float(np.percentile(finite_data, 2))
|
| 123 |
+
data_max = float(np.percentile(finite_data, 98))
|
| 124 |
+
|
| 125 |
+
# Ensure min != max
|
| 126 |
+
if abs(data_max - data_min) < 1e-10:
|
| 127 |
+
data_range = abs(data_min) * 0.1 if data_min != 0 else 1e-6
|
| 128 |
+
data_min -= data_range
|
| 129 |
+
data_max += data_range
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
else:
|
| 131 |
+
data_min, data_max = 0, 1
|
| 132 |
+
|
| 133 |
+
# Style parameters
|
| 134 |
+
style = {
|
| 135 |
+
'cmap': colormap,
|
| 136 |
+
'colorbar': True,
|
| 137 |
+
'vmin': data_min,
|
| 138 |
+
'vmax': data_max
|
| 139 |
+
}
|
| 140 |
|
| 141 |
+
# Create appropriate plot
|
| 142 |
+
n_dims = len(da_plot.dims)
|
| 143 |
+
|
| 144 |
+
if n_dims == 1:
|
| 145 |
+
fig = plot_1d(da_plot, **style)
|
| 146 |
+
elif n_dims == 2:
|
| 147 |
+
coords = identify_coordinates(da_plot)
|
| 148 |
+
if 'X' in coords and 'Y' in coords and plot_type == "Map":
|
| 149 |
+
style.update({
|
| 150 |
+
'proj': 'PlateCarree',
|
| 151 |
+
'coastlines': True,
|
| 152 |
+
'gridlines': True
|
| 153 |
+
})
|
| 154 |
+
fig = plot_map(da_plot, **style)
|
| 155 |
+
else:
|
| 156 |
+
fig = plot_2d(da_plot, kind="image", **style)
|
| 157 |
+
else:
|
| 158 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 159 |
+
ax.text(0.5, 0.5, f'π§ Cannot plot {n_dims}D data.\nAdjust sliders to reduce dimensions.',
|
| 160 |
+
ha='center', va='center', transform=ax.transAxes, fontsize=12)
|
| 161 |
+
return fig, f"β Too many dimensions: {n_dims}D"
|
| 162 |
+
|
| 163 |
+
# Add selection info to title
|
| 164 |
+
if selection_dict:
|
| 165 |
+
selection_str = ', '.join([f"{k}={v}" for k, v in selection_dict.items()])
|
| 166 |
+
fig.suptitle(f"{var_info['long_name']} ({selection_str})", fontsize=14)
|
| 167 |
+
|
| 168 |
+
# Create info string
|
| 169 |
+
data_info = f"""π **Plot Info:**
|
| 170 |
+
- **Shape:** {da_plot.shape}
|
| 171 |
+
- **Range:** {data_min:.4g} to {data_max:.4g} {var_info['units']}
|
| 172 |
+
- **Plot:** {plot_type} | **Colormap:** {colormap}
|
| 173 |
+
- **Valid data:** {len(finite_data):,} points
|
| 174 |
+
"""
|
| 175 |
|
| 176 |
+
return fig, data_info
|
|
|
|
| 177 |
|
| 178 |
except Exception as e:
|
| 179 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 180 |
+
ax.text(0.5, 0.5, f'β Error:\n{str(e)}', ha='center', va='center',
|
| 181 |
+
transform=ax.transAxes, fontsize=12, color='red')
|
| 182 |
+
return fig, f"β Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
+
def update_sliders_for_variable(variable: str):
|
| 185 |
+
"""Update sliders based on selected variable."""
|
| 186 |
+
global current_data
|
|
|
|
| 187 |
|
| 188 |
+
if not variable or not current_data:
|
| 189 |
+
return [gr.update(visible=False) for _ in range(8)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
+
var_info = current_data['variables'][variable]
|
| 192 |
+
dims = var_info['dims']
|
| 193 |
+
dims_info = current_data['dimensions']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
+
updates = []
|
| 196 |
+
for dim in dims:
|
| 197 |
+
if any(spatial in dim.lower() for spatial in ['lat', 'lon', 'x', 'y']):
|
| 198 |
+
continue # Skip spatial dimensions
|
| 199 |
+
|
| 200 |
+
if len(updates) < 8 and dim in dims_info:
|
| 201 |
+
dim_data = dims_info[dim]
|
| 202 |
+
updates.append(gr.update(
|
| 203 |
+
visible=True,
|
| 204 |
+
label=f"{dim} (0-{dim_data['size']-1})",
|
| 205 |
+
minimum=0,
|
| 206 |
+
maximum=dim_data['size'] - 1,
|
| 207 |
+
value=0,
|
| 208 |
+
step=1
|
| 209 |
+
))
|
| 210 |
|
| 211 |
+
# Hide remaining sliders
|
| 212 |
+
while len(updates) < 8:
|
| 213 |
+
updates.append(gr.update(visible=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
return updates
|
| 216 |
+
|
| 217 |
+
# Create Gradio Interface
|
| 218 |
+
with gr.Blocks(title="π TensorView - NetCDF/HDF/GRIB Viewer", theme=gr.themes.Ocean()) as app:
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
+
gr.HTML("""
|
| 221 |
+
<div style="text-align: center; padding: 20px;">
|
| 222 |
+
<h1>π TensorView</h1>
|
| 223 |
+
<p><strong>Interactive viewer for NetCDF, HDF, GRIB, and Zarr datasets</strong></p>
|
| 224 |
+
<p>Upload your data files or provide a file path, then explore with maps and plots!</p>
|
| 225 |
+
</div>
|
| 226 |
+
""")
|
| 227 |
|
|
|
|
| 228 |
with gr.Row():
|
| 229 |
with gr.Column(scale=1):
|
| 230 |
+
# Data Input Section
|
| 231 |
+
gr.Markdown("### π Data Input")
|
| 232 |
+
with gr.Tabs():
|
| 233 |
+
with gr.Tab("π€ Upload File"):
|
| 234 |
+
file_upload = gr.File(
|
| 235 |
+
label="Select data file",
|
| 236 |
+
file_types=[".nc", ".netcdf", ".hdf", ".h5", ".grib", ".grb"]
|
| 237 |
+
)
|
| 238 |
+
with gr.Tab("π File Path"):
|
| 239 |
+
file_path_input = gr.Textbox(
|
| 240 |
+
label="Enter file path",
|
| 241 |
+
placeholder="/path/to/your/data.nc",
|
| 242 |
+
lines=2
|
| 243 |
+
)
|
| 244 |
+
load_path_btn = gr.Button("Load File", variant="primary")
|
| 245 |
+
|
| 246 |
+
# Plot Configuration
|
| 247 |
+
gr.Markdown("### π¨ Plot Configuration")
|
| 248 |
+
variable_select = gr.Dropdown(
|
| 249 |
+
label="Variable",
|
| 250 |
choices=[],
|
| 251 |
interactive=True
|
| 252 |
)
|
| 253 |
|
| 254 |
plot_type = gr.Radio(
|
| 255 |
label="Plot Type",
|
| 256 |
+
choices=["2D Image", "Map"],
|
| 257 |
+
value="2D Image"
|
| 258 |
)
|
| 259 |
|
| 260 |
+
colormap_select = gr.Dropdown(
|
| 261 |
label="Colormap",
|
| 262 |
+
choices=["viridis", "plasma", "coolwarm", "RdBu_r", "Blues", "Reds", "turbo"],
|
|
|
|
| 263 |
value="viridis"
|
| 264 |
)
|
| 265 |
|
| 266 |
+
# Dimension Sliders
|
| 267 |
+
gr.Markdown("### ποΈ Dimension Controls")
|
| 268 |
+
gr.Markdown("*Sliders appear automatically based on selected variable*")
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
+
slider_components = [
|
| 271 |
+
gr.Slider(visible=False, interactive=True) for _ in range(8)
|
| 272 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
+
plot_btn = gr.Button("π¨ Create Plot", variant="primary", size="lg")
|
| 275 |
+
|
| 276 |
+
with gr.Column(scale=2):
|
| 277 |
+
# Results Section
|
| 278 |
+
data_info = gr.Markdown("*Load a dataset to begin*")
|
| 279 |
+
plot_display = gr.Plot(label="Visualization")
|
| 280 |
+
plot_info = gr.Markdown("")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
+
# Hidden state
|
| 283 |
file_path_state = gr.State("")
|
| 284 |
|
| 285 |
# Event handlers
|
| 286 |
+
def handle_file_upload(file):
|
| 287 |
+
summary, variables, path = process_file_or_path(file)
|
| 288 |
+
return (
|
| 289 |
+
summary,
|
| 290 |
+
gr.update(choices=variables, value=variables[0] if variables else None),
|
| 291 |
+
path
|
| 292 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
+
def handle_path_load(file_path):
|
| 295 |
+
summary, variables, path = process_file_or_path(file_path)
|
| 296 |
+
return (
|
| 297 |
+
summary,
|
| 298 |
+
gr.update(choices=variables, value=variables[0] if variables else None),
|
| 299 |
+
path
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
def handle_variable_change(variable):
|
| 303 |
+
return update_sliders_for_variable(variable)
|
| 304 |
|
| 305 |
+
def handle_plot_creation(file_path, variable, plot_type, colormap, *slider_vals):
|
| 306 |
+
fig, info = create_plot_with_auto_slicing(file_path, variable, plot_type, colormap, *slider_vals)
|
| 307 |
+
return fig, info
|
| 308 |
|
| 309 |
+
# Connect events
|
| 310 |
file_upload.upload(
|
| 311 |
+
fn=handle_file_upload,
|
| 312 |
inputs=[file_upload],
|
| 313 |
+
outputs=[data_info, variable_select, file_path_state]
|
| 314 |
)
|
| 315 |
|
| 316 |
+
load_path_btn.click(
|
| 317 |
+
fn=handle_path_load,
|
| 318 |
+
inputs=[file_path_input],
|
| 319 |
+
outputs=[data_info, variable_select, file_path_state]
|
| 320 |
)
|
| 321 |
|
| 322 |
+
variable_select.change(
|
| 323 |
+
fn=handle_variable_change,
|
| 324 |
+
inputs=[variable_select],
|
| 325 |
+
outputs=slider_components
|
|
|
|
| 326 |
)
|
| 327 |
|
| 328 |
+
plot_btn.click(
|
| 329 |
+
fn=handle_plot_creation,
|
| 330 |
+
inputs=[file_path_state, variable_select, plot_type, colormap_select] + slider_components,
|
| 331 |
+
outputs=[plot_display, plot_info]
|
|
|
|
|
|
|
| 332 |
)
|
| 333 |
|
| 334 |
if __name__ == "__main__":
|
assets/colormaps/sample.cpt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sample GMT-style colormap
|
| 2 |
+
# Blue to Red temperature colormap
|
| 3 |
+
-10.0 0 0 255 -5.0 50 50 255
|
| 4 |
+
-5.0 50 50 255 0.0 100 100 255
|
| 5 |
+
0.0 100 100 255 5.0 150 200 255
|
| 6 |
+
5.0 150 200 255 10.0 200 255 200
|
| 7 |
+
10.0 200 255 200 15.0 255 255 100
|
| 8 |
+
15.0 255 255 100 20.0 255 200 50
|
| 9 |
+
20.0 255 200 50 25.0 255 150 0
|
| 10 |
+
25.0 255 150 0 30.0 255 100 0
|
| 11 |
+
30.0 255 100 0 35.0 255 50 0
|
| 12 |
+
35.0 255 50 0 40.0 255 0 0
|
requirements.txt
CHANGED
|
@@ -1,10 +1,20 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
xarray
|
| 2 |
+
dask[complete]
|
| 3 |
+
fsspec
|
| 4 |
+
s3fs
|
| 5 |
+
zarr
|
| 6 |
+
h5netcdf
|
| 7 |
+
netcdf4
|
| 8 |
+
cfgrib
|
| 9 |
+
eccodes
|
| 10 |
+
rioxarray
|
| 11 |
+
cartopy
|
| 12 |
+
pyproj
|
| 13 |
+
matplotlib
|
| 14 |
+
gradio>=4
|
| 15 |
+
numpy
|
| 16 |
+
pandas
|
| 17 |
+
pydantic
|
| 18 |
+
orjson
|
| 19 |
+
pytest
|
| 20 |
+
ruff
|
tensorview/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TensorView: A browser-based netCDF/HDF/GRIB/Zarr dataset viewer."""
|
| 2 |
+
|
| 3 |
+
__version__ = "1.0.0"
|
tensorview/anim.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Animation functionality for creating MP4 videos from multi-dimensional data."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
import subprocess
|
| 6 |
+
from typing import Optional, Callable, List
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from matplotlib.animation import FuncAnimation
|
| 10 |
+
import xarray as xr
|
| 11 |
+
|
| 12 |
+
from .plot import plot_1d, plot_2d, plot_map, setup_matplotlib
|
| 13 |
+
from .utils import identify_coordinates, format_value
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def check_ffmpeg():
|
| 17 |
+
"""Check if FFmpeg is available."""
|
| 18 |
+
try:
|
| 19 |
+
subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
|
| 20 |
+
return True
|
| 21 |
+
except (subprocess.CalledProcessError, FileNotFoundError):
|
| 22 |
+
return False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def animate_over_dim(da: xr.DataArray, dim: str, plot_func: Callable = None,
|
| 26 |
+
fps: int = 10, out: str = "animation.mp4",
|
| 27 |
+
figsize: tuple = (10, 8), **plot_kwargs) -> str:
|
| 28 |
+
"""
|
| 29 |
+
Create an animation over a specified dimension.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
da: Input DataArray
|
| 33 |
+
dim: Dimension to animate over
|
| 34 |
+
plot_func: Plotting function to use (auto-detected if None)
|
| 35 |
+
fps: Frames per second
|
| 36 |
+
out: Output file path
|
| 37 |
+
figsize: Figure size
|
| 38 |
+
**plot_kwargs: Additional plotting parameters
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Path to the created animation file
|
| 42 |
+
"""
|
| 43 |
+
if not check_ffmpeg():
|
| 44 |
+
raise RuntimeError("FFmpeg is required for creating MP4 animations")
|
| 45 |
+
|
| 46 |
+
if dim not in da.dims:
|
| 47 |
+
raise ValueError(f"Dimension '{dim}' not found in DataArray")
|
| 48 |
+
|
| 49 |
+
setup_matplotlib()
|
| 50 |
+
|
| 51 |
+
# Get coordinate values for the animation dimension
|
| 52 |
+
coord_vals = da.coords[dim].values
|
| 53 |
+
n_frames = len(coord_vals)
|
| 54 |
+
|
| 55 |
+
if n_frames < 2:
|
| 56 |
+
raise ValueError(f"Need at least 2 frames for animation, got {n_frames}")
|
| 57 |
+
|
| 58 |
+
# Auto-detect plot function if not provided
|
| 59 |
+
if plot_func is None:
|
| 60 |
+
remaining_dims = [d for d in da.dims if d != dim]
|
| 61 |
+
n_remaining = len(remaining_dims)
|
| 62 |
+
|
| 63 |
+
# Check if we have geographic coordinates
|
| 64 |
+
coords = identify_coordinates(da)
|
| 65 |
+
has_geo = 'X' in coords and 'Y' in coords
|
| 66 |
+
|
| 67 |
+
if n_remaining == 1:
|
| 68 |
+
plot_func = plot_1d
|
| 69 |
+
elif n_remaining == 2 and has_geo:
|
| 70 |
+
plot_func = plot_map
|
| 71 |
+
elif n_remaining == 2:
|
| 72 |
+
plot_func = plot_2d
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")
|
| 75 |
+
|
| 76 |
+
# Create figure and initial plot
|
| 77 |
+
fig, ax = plt.subplots(figsize=figsize)
|
| 78 |
+
|
| 79 |
+
# Get initial frame
|
| 80 |
+
initial_frame = da.isel({dim: 0})
|
| 81 |
+
|
| 82 |
+
# Set up consistent color limits across all frames
|
| 83 |
+
if 'vmin' not in plot_kwargs:
|
| 84 |
+
plot_kwargs['vmin'] = float(da.min().values)
|
| 85 |
+
if 'vmax' not in plot_kwargs:
|
| 86 |
+
plot_kwargs['vmax'] = float(da.max().values)
|
| 87 |
+
|
| 88 |
+
# Create initial plot to get the structure
|
| 89 |
+
if plot_func == plot_1d:
|
| 90 |
+
line, = ax.plot([], [])
|
| 91 |
+
ax.set_xlim(float(initial_frame.coords[initial_frame.dims[0]].min()),
|
| 92 |
+
float(initial_frame.coords[initial_frame.dims[0]].max()))
|
| 93 |
+
ax.set_ylim(plot_kwargs['vmin'], plot_kwargs['vmax'])
|
| 94 |
+
|
| 95 |
+
# Set labels
|
| 96 |
+
x_dim = initial_frame.dims[0]
|
| 97 |
+
ax.set_xlabel(f"{x_dim} ({initial_frame.coords[x_dim].attrs.get('units', '')})")
|
| 98 |
+
ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
|
| 99 |
+
|
| 100 |
+
def animate(frame_idx):
|
| 101 |
+
frame_data = da.isel({dim: frame_idx})
|
| 102 |
+
x_data = frame_data.coords[x_dim]
|
| 103 |
+
line.set_data(x_data, frame_data)
|
| 104 |
+
|
| 105 |
+
# Update title with current time/coordinate value
|
| 106 |
+
coord_val = coord_vals[frame_idx]
|
| 107 |
+
coord_str = format_value(coord_val, dim)
|
| 108 |
+
title = f"{da.attrs.get('long_name', da.name or 'Data')} - {dim}={coord_str}"
|
| 109 |
+
ax.set_title(title)
|
| 110 |
+
|
| 111 |
+
return line,
|
| 112 |
+
|
| 113 |
+
elif plot_func in [plot_2d, plot_map]:
|
| 114 |
+
# For 2D plots, we need to recreate the plot each frame
|
| 115 |
+
def animate(frame_idx):
|
| 116 |
+
ax.clear()
|
| 117 |
+
frame_data = da.isel({dim: frame_idx})
|
| 118 |
+
|
| 119 |
+
# Create the plot
|
| 120 |
+
if plot_func == plot_map:
|
| 121 |
+
# Special handling for map plots
|
| 122 |
+
import cartopy.crs as ccrs
|
| 123 |
+
import cartopy.feature as cfeature
|
| 124 |
+
|
| 125 |
+
proj = plot_kwargs.get('proj', 'PlateCarree')
|
| 126 |
+
proj_map = {
|
| 127 |
+
'PlateCarree': ccrs.PlateCarree(),
|
| 128 |
+
'Robinson': ccrs.Robinson(),
|
| 129 |
+
'Mollweide': ccrs.Mollweide()
|
| 130 |
+
}
|
| 131 |
+
projection = proj_map.get(proj, ccrs.PlateCarree())
|
| 132 |
+
|
| 133 |
+
coords = identify_coordinates(frame_data)
|
| 134 |
+
lon_dim = coords['X']
|
| 135 |
+
lat_dim = coords['Y']
|
| 136 |
+
|
| 137 |
+
lons = frame_data.coords[lon_dim].values
|
| 138 |
+
lats = frame_data.coords[lat_dim].values
|
| 139 |
+
|
| 140 |
+
# Create pcolormesh plot
|
| 141 |
+
cmap = plot_kwargs.get('cmap', 'viridis')
|
| 142 |
+
im = ax.pcolormesh(lons, lats, frame_data.transpose(lat_dim, lon_dim).values,
|
| 143 |
+
cmap=cmap, vmin=plot_kwargs['vmin'], vmax=plot_kwargs['vmax'],
|
| 144 |
+
transform=ccrs.PlateCarree(), shading='auto')
|
| 145 |
+
|
| 146 |
+
# Add map features
|
| 147 |
+
if plot_kwargs.get('coastlines', True):
|
| 148 |
+
ax.coastlines(resolution='50m', color='black', linewidth=0.5)
|
| 149 |
+
if plot_kwargs.get('gridlines', True):
|
| 150 |
+
ax.gridlines(alpha=0.5)
|
| 151 |
+
|
| 152 |
+
ax.set_global()
|
| 153 |
+
|
| 154 |
+
else:
|
| 155 |
+
# Regular 2D plot
|
| 156 |
+
coords = identify_coordinates(frame_data)
|
| 157 |
+
x_dim = coords.get('X', frame_data.dims[-1])
|
| 158 |
+
y_dim = coords.get('Y', frame_data.dims[-2])
|
| 159 |
+
|
| 160 |
+
frame_plot = frame_data.transpose(y_dim, x_dim)
|
| 161 |
+
x_coord = frame_data.coords[x_dim]
|
| 162 |
+
y_coord = frame_data.coords[y_dim]
|
| 163 |
+
|
| 164 |
+
im = ax.imshow(frame_plot.values,
|
| 165 |
+
extent=[float(x_coord.min()), float(x_coord.max()),
|
| 166 |
+
float(y_coord.min()), float(y_coord.max())],
|
| 167 |
+
aspect='auto', origin='lower',
|
| 168 |
+
cmap=plot_kwargs.get('cmap', 'viridis'),
|
| 169 |
+
vmin=plot_kwargs['vmin'], vmax=plot_kwargs['vmax'])
|
| 170 |
+
|
| 171 |
+
ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
|
| 172 |
+
ax.set_ylabel(f"{y_dim} ({y_coord.attrs.get('units', '')})")
|
| 173 |
+
|
| 174 |
+
# Update title
|
| 175 |
+
coord_val = coord_vals[frame_idx]
|
| 176 |
+
coord_str = format_value(coord_val, dim)
|
| 177 |
+
title = f"{da.attrs.get('long_name', da.name or 'Data')} - {dim}={coord_str}"
|
| 178 |
+
ax.set_title(title)
|
| 179 |
+
|
| 180 |
+
return [im] if 'im' in locals() else []
|
| 181 |
+
|
| 182 |
+
# Create animation
|
| 183 |
+
anim = FuncAnimation(fig, animate, frames=n_frames, interval=1000//fps, blit=False)
|
| 184 |
+
|
| 185 |
+
# Save animation
|
| 186 |
+
try:
|
| 187 |
+
# Use FFmpeg writer
|
| 188 |
+
Writer = plt.matplotlib.animation.writers['ffmpeg']
|
| 189 |
+
writer = Writer(fps=fps, metadata=dict(artist='TensorView'), bitrate=1800)
|
| 190 |
+
anim.save(out, writer=writer)
|
| 191 |
+
|
| 192 |
+
plt.close(fig)
|
| 193 |
+
return out
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
plt.close(fig)
|
| 197 |
+
raise RuntimeError(f"Failed to create animation: {str(e)}")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def create_frame_sequence(da: xr.DataArray, dim: str, plot_func: Callable = None,
|
| 201 |
+
output_dir: str = "frames", **plot_kwargs) -> List[str]:
|
| 202 |
+
"""
|
| 203 |
+
Create a sequence of individual frame images.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
da: Input DataArray
|
| 207 |
+
dim: Dimension to animate over
|
| 208 |
+
plot_func: Plotting function to use
|
| 209 |
+
output_dir: Directory to save frames
|
| 210 |
+
**plot_kwargs: Additional plotting parameters
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
List of frame file paths
|
| 214 |
+
"""
|
| 215 |
+
if dim not in da.dims:
|
| 216 |
+
raise ValueError(f"Dimension '{dim}' not found in DataArray")
|
| 217 |
+
|
| 218 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 219 |
+
|
| 220 |
+
coord_vals = da.coords[dim].values
|
| 221 |
+
frame_paths = []
|
| 222 |
+
|
| 223 |
+
# Auto-detect plot function if not provided
|
| 224 |
+
if plot_func is None:
|
| 225 |
+
remaining_dims = [d for d in da.dims if d != dim]
|
| 226 |
+
n_remaining = len(remaining_dims)
|
| 227 |
+
|
| 228 |
+
coords = identify_coordinates(da)
|
| 229 |
+
has_geo = 'X' in coords and 'Y' in coords
|
| 230 |
+
|
| 231 |
+
if n_remaining == 1:
|
| 232 |
+
plot_func = plot_1d
|
| 233 |
+
elif n_remaining == 2 and has_geo:
|
| 234 |
+
plot_func = plot_map
|
| 235 |
+
elif n_remaining == 2:
|
| 236 |
+
plot_func = plot_2d
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")
|
| 239 |
+
|
| 240 |
+
# Set consistent color limits
|
| 241 |
+
if 'vmin' not in plot_kwargs:
|
| 242 |
+
plot_kwargs['vmin'] = float(da.min().values)
|
| 243 |
+
if 'vmax' not in plot_kwargs:
|
| 244 |
+
plot_kwargs['vmax'] = float(da.max().values)
|
| 245 |
+
|
| 246 |
+
# Create frames
|
| 247 |
+
for i, coord_val in enumerate(coord_vals):
|
| 248 |
+
frame_data = da.isel({dim: i})
|
| 249 |
+
|
| 250 |
+
# Create plot
|
| 251 |
+
fig = plot_func(frame_data, **plot_kwargs)
|
| 252 |
+
|
| 253 |
+
# Update title with coordinate value
|
| 254 |
+
coord_str = format_value(coord_val, dim)
|
| 255 |
+
fig.suptitle(f"{da.attrs.get('long_name', da.name or 'Data')} - {dim}={coord_str}")
|
| 256 |
+
|
| 257 |
+
# Save frame
|
| 258 |
+
frame_path = os.path.join(output_dir, f"frame_{i:04d}.png")
|
| 259 |
+
fig.savefig(frame_path, dpi=150, bbox_inches='tight')
|
| 260 |
+
frame_paths.append(frame_path)
|
| 261 |
+
|
| 262 |
+
plt.close(fig)
|
| 263 |
+
|
| 264 |
+
return frame_paths
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def frames_to_mp4(frame_dir: str, output_path: str, fps: int = 10, cleanup: bool = True) -> str:
|
| 268 |
+
"""
|
| 269 |
+
Convert a directory of frame images to MP4 video.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
frame_dir: Directory containing frame images
|
| 273 |
+
output_path: Output MP4 file path
|
| 274 |
+
fps: Frames per second
|
| 275 |
+
cleanup: Whether to delete frame files after conversion
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
Path to created MP4 file
|
| 279 |
+
"""
|
| 280 |
+
if not check_ffmpeg():
|
| 281 |
+
raise RuntimeError("FFmpeg is required for MP4 conversion")
|
| 282 |
+
|
| 283 |
+
# Build FFmpeg command
|
| 284 |
+
cmd = [
|
| 285 |
+
'ffmpeg', '-y', # Overwrite output
|
| 286 |
+
'-framerate', str(fps),
|
| 287 |
+
'-pattern_type', 'glob',
|
| 288 |
+
'-i', os.path.join(frame_dir, 'frame_*.png'),
|
| 289 |
+
'-c:v', 'libx264',
|
| 290 |
+
'-pix_fmt', 'yuv420p',
|
| 291 |
+
'-crf', '18', # High quality
|
| 292 |
+
output_path
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
try:
|
| 296 |
+
subprocess.run(cmd, check=True, capture_output=True)
|
| 297 |
+
|
| 298 |
+
# Clean up frame files if requested
|
| 299 |
+
if cleanup:
|
| 300 |
+
import glob
|
| 301 |
+
for frame_file in glob.glob(os.path.join(frame_dir, 'frame_*.png')):
|
| 302 |
+
os.remove(frame_file)
|
| 303 |
+
|
| 304 |
+
# Remove directory if empty
|
| 305 |
+
try:
|
| 306 |
+
os.rmdir(frame_dir)
|
| 307 |
+
except OSError:
|
| 308 |
+
pass # Directory not empty
|
| 309 |
+
|
| 310 |
+
return output_path
|
| 311 |
+
|
| 312 |
+
except subprocess.CalledProcessError as e:
|
| 313 |
+
raise RuntimeError(f"FFmpeg failed: {e.stderr.decode()}")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def create_gif(da: xr.DataArray, dim: str, output_path: str = "animation.gif",
|
| 317 |
+
duration: int = 200, plot_func: Callable = None, **plot_kwargs) -> str:
|
| 318 |
+
"""
|
| 319 |
+
Create an animated GIF.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
da: Input DataArray
|
| 323 |
+
dim: Dimension to animate over
|
| 324 |
+
output_path: Output GIF file path
|
| 325 |
+
duration: Duration per frame in milliseconds
|
| 326 |
+
plot_func: Plotting function to use
|
| 327 |
+
**plot_kwargs: Additional plotting parameters
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
Path to created GIF file
|
| 331 |
+
"""
|
| 332 |
+
try:
|
| 333 |
+
from PIL import Image
|
| 334 |
+
except ImportError:
|
| 335 |
+
raise ImportError("Pillow is required for GIF creation")
|
| 336 |
+
|
| 337 |
+
# Create frame sequence
|
| 338 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 339 |
+
frame_paths = create_frame_sequence(da, dim, plot_func, temp_dir, **plot_kwargs)
|
| 340 |
+
|
| 341 |
+
# Load frames and create GIF
|
| 342 |
+
images = []
|
| 343 |
+
for frame_path in frame_paths:
|
| 344 |
+
img = Image.open(frame_path)
|
| 345 |
+
images.append(img)
|
| 346 |
+
|
| 347 |
+
# Save as GIF
|
| 348 |
+
images[0].save(output_path, save_all=True, append_images=images[1:],
|
| 349 |
+
duration=duration, loop=0)
|
| 350 |
+
|
| 351 |
+
return output_path
|
tensorview/colors.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Colormap loading and management for CPT/ACT/RGB files."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from typing import List, Tuple, Optional
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib.colors as mcolors
|
| 9 |
+
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_cpt_file(filepath: str) -> LinearSegmentedColormap:
|
| 13 |
+
"""
|
| 14 |
+
Load a GMT-style CPT (Color Palette Table) file.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
filepath: Path to the CPT file
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
matplotlib LinearSegmentedColormap
|
| 21 |
+
"""
|
| 22 |
+
colors = []
|
| 23 |
+
positions = []
|
| 24 |
+
|
| 25 |
+
with open(filepath, 'r') as f:
|
| 26 |
+
lines = f.readlines()
|
| 27 |
+
|
| 28 |
+
# Parse CPT format
|
| 29 |
+
for line in lines:
|
| 30 |
+
line = line.strip()
|
| 31 |
+
|
| 32 |
+
# Skip comments and empty lines
|
| 33 |
+
if line.startswith('#') or not line or line.startswith('B') or line.startswith('F') or line.startswith('N'):
|
| 34 |
+
continue
|
| 35 |
+
|
| 36 |
+
parts = line.split()
|
| 37 |
+
if len(parts) >= 8:
|
| 38 |
+
# CPT format: z0 r0 g0 b0 z1 r1 g1 b1
|
| 39 |
+
try:
|
| 40 |
+
z0, r0, g0, b0, z1, r1, g1, b1 = map(float, parts[:8])
|
| 41 |
+
|
| 42 |
+
# Normalize RGB values if they're in 0-255 range
|
| 43 |
+
if r0 > 1 or g0 > 1 or b0 > 1:
|
| 44 |
+
r0, g0, b0 = r0/255, g0/255, b0/255
|
| 45 |
+
if r1 > 1 or g1 > 1 or b1 > 1:
|
| 46 |
+
r1, g1, b1 = r1/255, g1/255, b1/255
|
| 47 |
+
|
| 48 |
+
colors.extend([(r0, g0, b0), (r1, g1, b1)])
|
| 49 |
+
positions.extend([z0, z1])
|
| 50 |
+
|
| 51 |
+
except ValueError:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
if not colors:
|
| 55 |
+
raise ValueError(f"No valid color data found in {filepath}")
|
| 56 |
+
|
| 57 |
+
# Normalize positions to 0-1 range
|
| 58 |
+
positions = np.array(positions)
|
| 59 |
+
if len(set(positions)) > 1:
|
| 60 |
+
positions = (positions - positions.min()) / (positions.max() - positions.min())
|
| 61 |
+
else:
|
| 62 |
+
positions = np.linspace(0, 1, len(positions))
|
| 63 |
+
|
| 64 |
+
# Create colormap
|
| 65 |
+
cmap_data = list(zip(positions, colors))
|
| 66 |
+
cmap_data.sort(key=lambda x: x[0]) # Sort by position
|
| 67 |
+
|
| 68 |
+
# Remove duplicates
|
| 69 |
+
unique_data = []
|
| 70 |
+
seen_positions = set()
|
| 71 |
+
for pos, color in cmap_data:
|
| 72 |
+
if pos not in seen_positions:
|
| 73 |
+
unique_data.append((pos, color))
|
| 74 |
+
seen_positions.add(pos)
|
| 75 |
+
|
| 76 |
+
positions, colors = zip(*unique_data)
|
| 77 |
+
|
| 78 |
+
# Create the colormap
|
| 79 |
+
name = os.path.splitext(os.path.basename(filepath))[0]
|
| 80 |
+
return LinearSegmentedColormap.from_list(name, list(zip(positions, colors)))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def parse_act_file(filepath: str) -> ListedColormap:
|
| 84 |
+
"""
|
| 85 |
+
Load an Adobe Color Table (ACT) file.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
filepath: Path to the ACT file
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
matplotlib ListedColormap
|
| 92 |
+
"""
|
| 93 |
+
with open(filepath, 'rb') as f:
|
| 94 |
+
data = f.read()
|
| 95 |
+
|
| 96 |
+
# ACT files are 768 bytes (256 colors * 3 RGB values)
|
| 97 |
+
if len(data) != 768:
|
| 98 |
+
raise ValueError(f"Invalid ACT file size: {len(data)} bytes (expected 768)")
|
| 99 |
+
|
| 100 |
+
colors = []
|
| 101 |
+
for i in range(0, 768, 3):
|
| 102 |
+
r, g, b = data[i], data[i+1], data[i+2]
|
| 103 |
+
colors.append((r/255.0, g/255.0, b/255.0))
|
| 104 |
+
|
| 105 |
+
name = os.path.splitext(os.path.basename(filepath))[0]
|
| 106 |
+
return ListedColormap(colors, name=name)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def parse_rgb_file(filepath: str) -> ListedColormap:
|
| 110 |
+
"""
|
| 111 |
+
Load a simple RGB text file (one color per line).
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
filepath: Path to the RGB file
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
matplotlib ListedColormap
|
| 118 |
+
"""
|
| 119 |
+
colors = []
|
| 120 |
+
|
| 121 |
+
with open(filepath, 'r') as f:
|
| 122 |
+
lines = f.readlines()
|
| 123 |
+
|
| 124 |
+
for line in lines:
|
| 125 |
+
line = line.strip()
|
| 126 |
+
|
| 127 |
+
# Skip comments and empty lines
|
| 128 |
+
if line.startswith('#') or not line:
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
# Parse RGB values
|
| 132 |
+
parts = line.split()
|
| 133 |
+
if len(parts) >= 3:
|
| 134 |
+
try:
|
| 135 |
+
r, g, b = map(float, parts[:3])
|
| 136 |
+
|
| 137 |
+
# Normalize RGB values if they're in 0-255 range
|
| 138 |
+
if r > 1 or g > 1 or b > 1:
|
| 139 |
+
r, g, b = r/255, g/255, b/255
|
| 140 |
+
|
| 141 |
+
colors.append((r, g, b))
|
| 142 |
+
except ValueError:
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
if not colors:
|
| 146 |
+
raise ValueError(f"No valid color data found in {filepath}")
|
| 147 |
+
|
| 148 |
+
name = os.path.splitext(os.path.basename(filepath))[0]
|
| 149 |
+
return ListedColormap(colors, name=name)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def load_cpt(filepath: str) -> LinearSegmentedColormap:
|
| 153 |
+
"""Load a CPT colormap file."""
|
| 154 |
+
return parse_cpt_file(filepath)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def load_act(filepath: str) -> ListedColormap:
|
| 158 |
+
"""Load an ACT colormap file."""
|
| 159 |
+
return parse_act_file(filepath)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def load_rgb(filepath: str) -> ListedColormap:
|
| 163 |
+
"""Load an RGB colormap file."""
|
| 164 |
+
return parse_rgb_file(filepath)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_colormap(filepath: str) -> mcolors.Colormap:
|
| 168 |
+
"""
|
| 169 |
+
Load a colormap from file, auto-detecting the format.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
filepath: Path to the colormap file
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
matplotlib Colormap
|
| 176 |
+
"""
|
| 177 |
+
_, ext = os.path.splitext(filepath.lower())
|
| 178 |
+
|
| 179 |
+
if ext == '.cpt':
|
| 180 |
+
return load_cpt(filepath)
|
| 181 |
+
elif ext == '.act':
|
| 182 |
+
return load_act(filepath)
|
| 183 |
+
elif ext in ['.rgb', '.txt']:
|
| 184 |
+
return load_rgb(filepath)
|
| 185 |
+
else:
|
| 186 |
+
# Try to guess based on content
|
| 187 |
+
try:
|
| 188 |
+
return load_cpt(filepath)
|
| 189 |
+
except:
|
| 190 |
+
try:
|
| 191 |
+
return load_rgb(filepath)
|
| 192 |
+
except:
|
| 193 |
+
raise ValueError(f"Cannot determine colormap format for {filepath}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_matplotlib_colormaps() -> List[str]:
|
| 197 |
+
"""Get list of available matplotlib colormaps."""
|
| 198 |
+
return sorted(plt.colormaps())
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def create_diverging_colormap(name: str, colors: List[str], center: float = 0.5) -> LinearSegmentedColormap:
|
| 202 |
+
"""
|
| 203 |
+
Create a diverging colormap.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
name: Name for the colormap
|
| 207 |
+
colors: List of color names/hex codes
|
| 208 |
+
center: Position of the center color (0-1)
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
LinearSegmentedColormap
|
| 212 |
+
"""
|
| 213 |
+
return LinearSegmentedColormap.from_list(name, colors)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def reverse_colormap(cmap: mcolors.Colormap) -> mcolors.Colormap:
|
| 217 |
+
"""Reverse a colormap."""
|
| 218 |
+
if hasattr(cmap, 'reversed'):
|
| 219 |
+
return cmap.reversed()
|
| 220 |
+
else:
|
| 221 |
+
# For custom colormaps
|
| 222 |
+
if isinstance(cmap, LinearSegmentedColormap):
|
| 223 |
+
return LinearSegmentedColormap.from_list(
|
| 224 |
+
f"{cmap.name}_r",
|
| 225 |
+
cmap(np.linspace(0, 1, 256))[::-1]
|
| 226 |
+
)
|
| 227 |
+
elif isinstance(cmap, ListedColormap):
|
| 228 |
+
return ListedColormap(
|
| 229 |
+
cmap.colors[::-1],
|
| 230 |
+
name=f"{cmap.name}_r"
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
return cmap
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def colormap_to_cpt(cmap: mcolors.Colormap, filepath: str, n_colors: int = 256):
|
| 237 |
+
"""
|
| 238 |
+
Save a matplotlib colormap as a CPT file.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
cmap: matplotlib Colormap
|
| 242 |
+
filepath: Output file path
|
| 243 |
+
n_colors: Number of color levels
|
| 244 |
+
"""
|
| 245 |
+
colors = cmap(np.linspace(0, 1, n_colors))
|
| 246 |
+
|
| 247 |
+
with open(filepath, 'w') as f:
|
| 248 |
+
f.write(f"# Color palette: {cmap.name}\n")
|
| 249 |
+
f.write("# Created by TensorView\n")
|
| 250 |
+
|
| 251 |
+
for i in range(len(colors) - 1):
|
| 252 |
+
r0, g0, b0 = colors[i][:3]
|
| 253 |
+
r1, g1, b1 = colors[i+1][:3]
|
| 254 |
+
|
| 255 |
+
z0 = i / (n_colors - 1)
|
| 256 |
+
z1 = (i + 1) / (n_colors - 1)
|
| 257 |
+
|
| 258 |
+
f.write(f"{z0:.6f}\t{r0*255:.0f}\t{g0*255:.0f}\t{b0*255:.0f}\t")
|
| 259 |
+
f.write(f"{z1:.6f}\t{r1*255:.0f}\t{g1*255:.0f}\t{b1*255:.0f}\n")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def get_colormap_info(cmap: mcolors.Colormap) -> dict:
|
| 263 |
+
"""Get information about a colormap."""
|
| 264 |
+
info = {
|
| 265 |
+
'name': getattr(cmap, 'name', 'unknown'),
|
| 266 |
+
'type': type(cmap).__name__,
|
| 267 |
+
'n_colors': getattr(cmap, 'N', 'continuous')
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
if isinstance(cmap, ListedColormap):
|
| 271 |
+
info['n_colors'] = len(cmap.colors)
|
| 272 |
+
|
| 273 |
+
return info
|
tensorview/grid.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grid alignment and combination operations."""
|
| 2 |
+
|
| 3 |
+
from typing import Literal, Dict, Any, Tuple
|
| 4 |
+
import numpy as np
|
| 5 |
+
import xarray as xr
|
| 6 |
+
from .utils import identify_coordinates, get_crs, is_geographic
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
|
| 10 |
+
"""
|
| 11 |
+
Align two DataArrays for combination operations.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
a, b: Input DataArrays
|
| 15 |
+
method: Alignment method ('reindex', 'interp')
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
Tuple of aligned DataArrays
|
| 19 |
+
"""
|
| 20 |
+
# Check CRS compatibility
|
| 21 |
+
crs_a = get_crs(a)
|
| 22 |
+
crs_b = get_crs(b)
|
| 23 |
+
|
| 24 |
+
if crs_a and crs_b and not crs_a.equals(crs_b):
|
| 25 |
+
raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")
|
| 26 |
+
|
| 27 |
+
# Get coordinate information
|
| 28 |
+
coords_a = identify_coordinates(a)
|
| 29 |
+
coords_b = identify_coordinates(b)
|
| 30 |
+
|
| 31 |
+
# Find common dimensions
|
| 32 |
+
common_dims = set(a.dims) & set(b.dims)
|
| 33 |
+
|
| 34 |
+
if not common_dims:
|
| 35 |
+
raise ValueError("No common dimensions found for alignment")
|
| 36 |
+
|
| 37 |
+
# Align coordinates
|
| 38 |
+
if method == "reindex":
|
| 39 |
+
# Use nearest neighbor reindexing
|
| 40 |
+
a_aligned = a
|
| 41 |
+
b_aligned = b
|
| 42 |
+
|
| 43 |
+
for dim in common_dims:
|
| 44 |
+
if dim in a.dims and dim in b.dims:
|
| 45 |
+
# Get the union of coordinates for this dimension
|
| 46 |
+
coord_a = a.coords[dim]
|
| 47 |
+
coord_b = b.coords[dim]
|
| 48 |
+
|
| 49 |
+
# Use the coordinate with higher resolution
|
| 50 |
+
if len(coord_a) >= len(coord_b):
|
| 51 |
+
target_coord = coord_a
|
| 52 |
+
else:
|
| 53 |
+
target_coord = coord_b
|
| 54 |
+
|
| 55 |
+
# Reindex both arrays to the target coordinate
|
| 56 |
+
a_aligned = a_aligned.reindex({dim: target_coord}, method='nearest')
|
| 57 |
+
b_aligned = b_aligned.reindex({dim: target_coord}, method='nearest')
|
| 58 |
+
|
| 59 |
+
elif method == "interp":
|
| 60 |
+
# Use interpolation
|
| 61 |
+
# Find common coordinate grid
|
| 62 |
+
common_coords = {}
|
| 63 |
+
for dim in common_dims:
|
| 64 |
+
if dim in a.dims and dim in b.dims:
|
| 65 |
+
coord_a = a.coords[dim]
|
| 66 |
+
coord_b = b.coords[dim]
|
| 67 |
+
|
| 68 |
+
# Create a common grid (intersection)
|
| 69 |
+
min_val = max(float(coord_a.min()), float(coord_b.min()))
|
| 70 |
+
max_val = min(float(coord_a.max()), float(coord_b.max()))
|
| 71 |
+
|
| 72 |
+
# Use the finer resolution
|
| 73 |
+
res_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0
|
| 74 |
+
res_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0
|
| 75 |
+
res = min(abs(res_a), abs(res_b))
|
| 76 |
+
|
| 77 |
+
common_coords[dim] = np.arange(min_val, max_val + res, res)
|
| 78 |
+
|
| 79 |
+
a_aligned = a.interp(common_coords)
|
| 80 |
+
b_aligned = b.interp(common_coords)
|
| 81 |
+
|
| 82 |
+
else:
|
| 83 |
+
raise ValueError(f"Unknown alignment method: {method}")
|
| 84 |
+
|
| 85 |
+
return a_aligned, b_aligned
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray:
|
| 89 |
+
"""
|
| 90 |
+
Combine two DataArrays with the specified operation.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
a, b: Input DataArrays
|
| 94 |
+
op: Operation ('sum', 'avg', 'diff')
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Combined DataArray
|
| 98 |
+
"""
|
| 99 |
+
# Align the arrays first
|
| 100 |
+
a_aligned, b_aligned = align_for_combine(a, b)
|
| 101 |
+
|
| 102 |
+
# Perform the operation
|
| 103 |
+
if op == "sum":
|
| 104 |
+
result = a_aligned + b_aligned
|
| 105 |
+
elif op == "avg":
|
| 106 |
+
result = (a_aligned + b_aligned) / 2
|
| 107 |
+
elif op == "diff":
|
| 108 |
+
result = a_aligned - b_aligned
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown operation: {op}")
|
| 111 |
+
|
| 112 |
+
# Update attributes
|
| 113 |
+
result.name = f"{a.name}_{op}_{b.name}"
|
| 114 |
+
|
| 115 |
+
if op == "sum":
|
| 116 |
+
result.attrs['long_name'] = f"{a.attrs.get('long_name', a.name)} + {b.attrs.get('long_name', b.name)}"
|
| 117 |
+
elif op == "avg":
|
| 118 |
+
result.attrs['long_name'] = f"Average of {a.attrs.get('long_name', a.name)} and {b.attrs.get('long_name', b.name)}"
|
| 119 |
+
elif op == "diff":
|
| 120 |
+
result.attrs['long_name'] = f"{a.attrs.get('long_name', a.name)} - {b.attrs.get('long_name', b.name)}"
|
| 121 |
+
|
| 122 |
+
# Preserve units if they match
|
| 123 |
+
if a.attrs.get('units') == b.attrs.get('units'):
|
| 124 |
+
result.attrs['units'] = a.attrs.get('units', '')
|
| 125 |
+
|
| 126 |
+
return result
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
|
| 130 |
+
"""
|
| 131 |
+
Create a cross-section of the DataArray.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
da: Input DataArray
|
| 135 |
+
along: Dimension to keep for the section (e.g., 'time', 'lat')
|
| 136 |
+
fixed: Dictionary of {dim: value} for dimensions to fix
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Cross-section DataArray
|
| 140 |
+
"""
|
| 141 |
+
if along not in da.dims:
|
| 142 |
+
raise ValueError(f"Dimension '{along}' not found in DataArray")
|
| 143 |
+
|
| 144 |
+
# Start with the full array
|
| 145 |
+
result = da
|
| 146 |
+
|
| 147 |
+
# Apply fixed selections
|
| 148 |
+
selection = {}
|
| 149 |
+
for dim, value in fixed.items():
|
| 150 |
+
if dim not in da.dims:
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
coord = da.coords[dim]
|
| 154 |
+
|
| 155 |
+
if isinstance(value, (int, float)):
|
| 156 |
+
# Select nearest value
|
| 157 |
+
selection[dim] = coord.sel({dim: value}, method='nearest')
|
| 158 |
+
elif isinstance(value, str) and 'time' in dim.lower():
|
| 159 |
+
# Handle time strings
|
| 160 |
+
selection[dim] = value
|
| 161 |
+
else:
|
| 162 |
+
selection[dim] = value
|
| 163 |
+
|
| 164 |
+
if selection:
|
| 165 |
+
result = result.sel(selection, method='nearest')
|
| 166 |
+
|
| 167 |
+
# Ensure the 'along' dimension is preserved
|
| 168 |
+
if along not in result.dims:
|
| 169 |
+
raise ValueError(f"Section operation removed the '{along}' dimension")
|
| 170 |
+
|
| 171 |
+
# Update metadata
|
| 172 |
+
result.attrs = da.attrs.copy()
|
| 173 |
+
|
| 174 |
+
# Add section info to long_name
|
| 175 |
+
section_info = []
|
| 176 |
+
for dim, value in fixed.items():
|
| 177 |
+
if dim in da.dims:
|
| 178 |
+
if isinstance(value, (int, float)):
|
| 179 |
+
section_info.append(f"{dim}={value:.3f}")
|
| 180 |
+
else:
|
| 181 |
+
section_info.append(f"{dim}={value}")
|
| 182 |
+
|
| 183 |
+
if section_info:
|
| 184 |
+
long_name = result.attrs.get('long_name', result.name)
|
| 185 |
+
result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})"
|
| 186 |
+
|
| 187 |
+
return result
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
|
| 191 |
+
"""
|
| 192 |
+
Aggregate spatially (e.g., zonal mean).
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
da: Input DataArray
|
| 196 |
+
method: Aggregation method ('mean', 'sum', 'std')
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Spatially aggregated DataArray
|
| 200 |
+
"""
|
| 201 |
+
coords = identify_coordinates(da)
|
| 202 |
+
|
| 203 |
+
spatial_dims = []
|
| 204 |
+
if 'X' in coords:
|
| 205 |
+
spatial_dims.append(coords['X'])
|
| 206 |
+
if 'Y' in coords:
|
| 207 |
+
spatial_dims.append(coords['Y'])
|
| 208 |
+
|
| 209 |
+
if not spatial_dims:
|
| 210 |
+
raise ValueError("No spatial dimensions found for aggregation")
|
| 211 |
+
|
| 212 |
+
# Perform aggregation
|
| 213 |
+
if method == "mean":
|
| 214 |
+
result = da.mean(dim=spatial_dims)
|
| 215 |
+
elif method == "sum":
|
| 216 |
+
result = da.sum(dim=spatial_dims)
|
| 217 |
+
elif method == "std":
|
| 218 |
+
result = da.std(dim=spatial_dims)
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError(f"Unknown aggregation method: {method}")
|
| 221 |
+
|
| 222 |
+
# Update attributes
|
| 223 |
+
result.attrs = da.attrs.copy()
|
| 224 |
+
long_name = result.attrs.get('long_name', result.name)
|
| 225 |
+
result.attrs['long_name'] = f"{method.capitalize()} of {long_name}"
|
| 226 |
+
|
| 227 |
+
return result
|
tensorview/io.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""I/O operations for opening and managing datasets."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Optional, Dict, Any, List
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
import xarray as xr
|
| 7 |
+
import fsspec
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class VariableSpec:
|
| 15 |
+
"""Variable specification with metadata."""
|
| 16 |
+
name: str
|
| 17 |
+
shape: tuple
|
| 18 |
+
dims: tuple
|
| 19 |
+
dtype: str
|
| 20 |
+
units: str
|
| 21 |
+
long_name: str
|
| 22 |
+
attrs: Dict[str, Any]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DatasetHandle:
|
| 26 |
+
"""Handle for opened datasets."""
|
| 27 |
+
|
| 28 |
+
def __init__(self, dataset: xr.Dataset, uri: str, engine: str):
|
| 29 |
+
self.dataset = dataset
|
| 30 |
+
self.uri = uri
|
| 31 |
+
self.engine = engine
|
| 32 |
+
|
| 33 |
+
def close(self):
|
| 34 |
+
"""Close the dataset."""
|
| 35 |
+
if hasattr(self.dataset, 'close'):
|
| 36 |
+
self.dataset.close()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def detect_engine(uri: str) -> str:
|
| 40 |
+
"""Auto-detect the appropriate engine for a given URI/file."""
|
| 41 |
+
if uri.lower().endswith('.zarr') or 'zarr' in uri.lower():
|
| 42 |
+
return 'zarr'
|
| 43 |
+
elif uri.lower().endswith('.grib') or uri.lower().endswith('.grb'):
|
| 44 |
+
return 'cfgrib'
|
| 45 |
+
elif any(ext in uri.lower() for ext in ['.nc', '.netcdf', '.hdf', '.h5']):
|
| 46 |
+
# Try h5netcdf first, fallback to netcdf4
|
| 47 |
+
try:
|
| 48 |
+
import h5netcdf
|
| 49 |
+
return 'h5netcdf'
|
| 50 |
+
except ImportError:
|
| 51 |
+
return 'netcdf4'
|
| 52 |
+
else:
|
| 53 |
+
# Default fallback
|
| 54 |
+
try:
|
| 55 |
+
import h5netcdf
|
| 56 |
+
return 'h5netcdf'
|
| 57 |
+
except ImportError:
|
| 58 |
+
return 'netcdf4'
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def open_any(uri: str, engine: Optional[str] = None, chunks: str = "auto") -> DatasetHandle:
|
| 62 |
+
"""
|
| 63 |
+
Open a dataset from various sources (local, HTTP, S3, etc.).
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
uri: Path or URL to dataset
|
| 67 |
+
engine: Engine to use ('h5netcdf', 'netcdf4', 'cfgrib', 'zarr')
|
| 68 |
+
chunks: Chunking strategy for dask
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
DatasetHandle: Handle to the opened dataset
|
| 72 |
+
"""
|
| 73 |
+
if engine is None:
|
| 74 |
+
engine = detect_engine(uri)
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
if engine == 'zarr':
|
| 78 |
+
# For Zarr stores
|
| 79 |
+
if uri.startswith('s3://'):
|
| 80 |
+
import s3fs
|
| 81 |
+
fs = s3fs.S3FileSystem(anon=True)
|
| 82 |
+
store = s3fs.S3Map(root=uri, s3=fs, check=False)
|
| 83 |
+
ds = xr.open_zarr(store, chunks=chunks)
|
| 84 |
+
else:
|
| 85 |
+
ds = xr.open_zarr(uri, chunks=chunks)
|
| 86 |
+
elif engine == 'cfgrib':
|
| 87 |
+
# For GRIB files
|
| 88 |
+
ds = xr.open_dataset(uri, engine='cfgrib', chunks=chunks)
|
| 89 |
+
else:
|
| 90 |
+
# For netCDF/HDF files
|
| 91 |
+
if uri.startswith(('http://', 'https://')):
|
| 92 |
+
# Remote files (including OPeNDAP)
|
| 93 |
+
ds = xr.open_dataset(uri, engine=engine, chunks=chunks)
|
| 94 |
+
elif uri.startswith('s3://'):
|
| 95 |
+
# S3 files
|
| 96 |
+
import s3fs
|
| 97 |
+
fs = s3fs.S3FileSystem(anon=True)
|
| 98 |
+
with fs.open(uri, 'rb') as f:
|
| 99 |
+
ds = xr.open_dataset(f, engine=engine, chunks=chunks)
|
| 100 |
+
else:
|
| 101 |
+
# Local files
|
| 102 |
+
ds = xr.open_dataset(uri, engine=engine, chunks=chunks)
|
| 103 |
+
|
| 104 |
+
return DatasetHandle(ds, uri, engine)
|
| 105 |
+
|
| 106 |
+
except Exception as e:
|
| 107 |
+
raise RuntimeError(f"Failed to open {uri} with engine {engine}: {str(e)}")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def list_variables(handle: DatasetHandle) -> List[VariableSpec]:
|
| 111 |
+
"""
|
| 112 |
+
List all data variables in the dataset with their specifications.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
handle: Dataset handle
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
List of VariableSpec objects
|
| 119 |
+
"""
|
| 120 |
+
variables = []
|
| 121 |
+
|
| 122 |
+
for var_name, var in handle.dataset.data_vars.items():
|
| 123 |
+
# Skip coordinate variables and bounds
|
| 124 |
+
if var_name.endswith('_bounds') or var_name in handle.dataset.coords:
|
| 125 |
+
continue
|
| 126 |
+
|
| 127 |
+
attrs = dict(var.attrs)
|
| 128 |
+
|
| 129 |
+
spec = VariableSpec(
|
| 130 |
+
name=var_name,
|
| 131 |
+
shape=var.shape,
|
| 132 |
+
dims=var.dims,
|
| 133 |
+
dtype=str(var.dtype),
|
| 134 |
+
units=attrs.get('units', ''),
|
| 135 |
+
long_name=attrs.get('long_name', var_name),
|
| 136 |
+
attrs=attrs
|
| 137 |
+
)
|
| 138 |
+
variables.append(spec)
|
| 139 |
+
|
| 140 |
+
return variables
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def get_dataarray(handle: DatasetHandle, var: str) -> xr.DataArray:
|
| 144 |
+
"""
|
| 145 |
+
Get a specific data array from the dataset.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
handle: Dataset handle
|
| 149 |
+
var: Variable name
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
xarray DataArray
|
| 153 |
+
"""
|
| 154 |
+
if var not in handle.dataset.data_vars:
|
| 155 |
+
raise ValueError(f"Variable '{var}' not found in dataset")
|
| 156 |
+
|
| 157 |
+
return handle.dataset[var]
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def close(handle: DatasetHandle):
|
| 161 |
+
"""Close a dataset handle."""
|
| 162 |
+
handle.close()
|
tensorview/plot.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Plotting functions for 1D, 2D, and map visualizations."""
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import os
|
| 5 |
+
from typing import Optional, Dict, Any, Tuple, Literal
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib.colors as mcolors
|
| 9 |
+
from matplotlib.figure import Figure
|
| 10 |
+
from matplotlib.axes import Axes
|
| 11 |
+
import xarray as xr
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import cartopy.crs as ccrs
|
| 15 |
+
import cartopy.feature as cfeature
|
| 16 |
+
HAS_CARTOPY = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
HAS_CARTOPY = False
|
| 19 |
+
|
| 20 |
+
from .utils import identify_coordinates, get_crs, is_geographic, format_value
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def setup_matplotlib():
|
| 24 |
+
"""Setup matplotlib with non-interactive backend."""
|
| 25 |
+
plt.switch_backend('Agg')
|
| 26 |
+
plt.style.use('default')
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def plot_1d(da: xr.DataArray, x_dim: Optional[str] = None, **style) -> Figure:
|
| 30 |
+
"""
|
| 31 |
+
Create a 1D line plot.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
da: Input DataArray (should be 1D or have only one varying dimension)
|
| 35 |
+
x_dim: Dimension to use as x-axis (auto-detected if None)
|
| 36 |
+
**style: Style parameters (color, linewidth, etc.)
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
matplotlib Figure
|
| 40 |
+
"""
|
| 41 |
+
setup_matplotlib()
|
| 42 |
+
|
| 43 |
+
# Find the appropriate dimension for x-axis
|
| 44 |
+
if x_dim is None:
|
| 45 |
+
# Find the first dimension with more than 1 element
|
| 46 |
+
for dim in da.dims:
|
| 47 |
+
if da.sizes[dim] > 1:
|
| 48 |
+
x_dim = dim
|
| 49 |
+
break
|
| 50 |
+
|
| 51 |
+
if x_dim is None:
|
| 52 |
+
raise ValueError("No suitable dimension found for 1D plot")
|
| 53 |
+
|
| 54 |
+
if x_dim not in da.dims:
|
| 55 |
+
raise ValueError(f"Dimension '{x_dim}' not found in DataArray")
|
| 56 |
+
|
| 57 |
+
# Create the figure
|
| 58 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 59 |
+
|
| 60 |
+
# Get data for plotting
|
| 61 |
+
x_data = da.coords[x_dim]
|
| 62 |
+
y_data = da
|
| 63 |
+
|
| 64 |
+
# Plot the data
|
| 65 |
+
line_style = {
|
| 66 |
+
'color': style.get('color', 'blue'),
|
| 67 |
+
'linewidth': style.get('linewidth', 1.5),
|
| 68 |
+
'linestyle': style.get('linestyle', '-'),
|
| 69 |
+
'marker': style.get('marker', ''),
|
| 70 |
+
'markersize': style.get('markersize', 4),
|
| 71 |
+
'alpha': style.get('alpha', 1.0)
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
ax.plot(x_data, y_data, **line_style)
|
| 75 |
+
|
| 76 |
+
# Set labels
|
| 77 |
+
ax.set_xlabel(f"{x_dim} ({x_data.attrs.get('units', '')})")
|
| 78 |
+
ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
|
| 79 |
+
|
| 80 |
+
# Set title
|
| 81 |
+
title = da.attrs.get('long_name', da.name or 'Data')
|
| 82 |
+
ax.set_title(title)
|
| 83 |
+
|
| 84 |
+
# Add grid if requested
|
| 85 |
+
if style.get('grid', True):
|
| 86 |
+
ax.grid(True, alpha=0.3)
|
| 87 |
+
|
| 88 |
+
# Handle time axis formatting
|
| 89 |
+
if 'time' in x_dim.lower() or x_data.dtype.kind == 'M':
|
| 90 |
+
fig.autofmt_xdate()
|
| 91 |
+
|
| 92 |
+
plt.tight_layout()
|
| 93 |
+
return fig
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def plot_2d(da: xr.DataArray, kind: Literal["image", "contour"] = "image",
|
| 97 |
+
x_dim: Optional[str] = None, y_dim: Optional[str] = None, **style) -> Figure:
|
| 98 |
+
"""
|
| 99 |
+
Create a 2D plot (image or contour).
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
da: Input DataArray (should be 2D)
|
| 103 |
+
kind: Plot type ('image' or 'contour')
|
| 104 |
+
x_dim, y_dim: Dimensions to use for axes
|
| 105 |
+
**style: Style parameters
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
matplotlib Figure
|
| 109 |
+
"""
|
| 110 |
+
setup_matplotlib()
|
| 111 |
+
|
| 112 |
+
# Auto-detect dimensions if not provided
|
| 113 |
+
if x_dim is None or y_dim is None:
|
| 114 |
+
coords = identify_coordinates(da)
|
| 115 |
+
if x_dim is None:
|
| 116 |
+
x_dim = coords.get('X', da.dims[-1]) # Default to last dimension
|
| 117 |
+
if y_dim is None:
|
| 118 |
+
y_dim = coords.get('Y', da.dims[-2]) # Default to second-to-last dimension
|
| 119 |
+
|
| 120 |
+
if x_dim not in da.dims or y_dim not in da.dims:
|
| 121 |
+
raise ValueError(f"Dimensions {x_dim}, {y_dim} not found in DataArray")
|
| 122 |
+
|
| 123 |
+
# Transpose to get (y, x) order for plotting
|
| 124 |
+
da_plot = da.transpose(y_dim, x_dim)
|
| 125 |
+
|
| 126 |
+
# Create figure
|
| 127 |
+
fig, ax = plt.subplots(figsize=(10, 8))
|
| 128 |
+
|
| 129 |
+
# Get coordinates
|
| 130 |
+
x_coord = da.coords[x_dim]
|
| 131 |
+
y_coord = da.coords[y_dim]
|
| 132 |
+
|
| 133 |
+
# Set up colormap
|
| 134 |
+
cmap = style.get('cmap', 'viridis')
|
| 135 |
+
if isinstance(cmap, str):
|
| 136 |
+
cmap = plt.get_cmap(cmap)
|
| 137 |
+
|
| 138 |
+
# Set up normalization
|
| 139 |
+
vmin = style.get('vmin', float(da.min().values))
|
| 140 |
+
vmax = style.get('vmax', float(da.max().values))
|
| 141 |
+
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
|
| 142 |
+
|
| 143 |
+
if kind == "image":
|
| 144 |
+
# Use imshow for regular grids
|
| 145 |
+
im = ax.imshow(da_plot.values,
|
| 146 |
+
extent=[float(x_coord.min()), float(x_coord.max()),
|
| 147 |
+
float(y_coord.min()), float(y_coord.max())],
|
| 148 |
+
aspect='auto', origin='lower', cmap=cmap, norm=norm)
|
| 149 |
+
|
| 150 |
+
elif kind == "contour":
|
| 151 |
+
# Use contourf for contour plots
|
| 152 |
+
levels = style.get('levels', 20)
|
| 153 |
+
if isinstance(levels, int):
|
| 154 |
+
levels = np.linspace(vmin, vmax, levels)
|
| 155 |
+
|
| 156 |
+
X, Y = np.meshgrid(x_coord, y_coord)
|
| 157 |
+
im = ax.contourf(X, Y, da_plot.values, levels=levels, cmap=cmap, norm=norm)
|
| 158 |
+
|
| 159 |
+
# Add contour lines if requested
|
| 160 |
+
if style.get('contour_lines', False):
|
| 161 |
+
cs = ax.contour(X, Y, da_plot.values, levels=levels, colors='k', linewidths=0.5)
|
| 162 |
+
ax.clabel(cs, inline=True, fontsize=8)
|
| 163 |
+
|
| 164 |
+
# Add colorbar
|
| 165 |
+
if style.get('colorbar', True):
|
| 166 |
+
cbar = plt.colorbar(im, ax=ax)
|
| 167 |
+
cbar.set_label(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
|
| 168 |
+
|
| 169 |
+
# Set labels
|
| 170 |
+
ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
|
| 171 |
+
ax.set_ylabel(f"{y_dim} ({y_coord.attrs.get('units', '')})")
|
| 172 |
+
|
| 173 |
+
# Set title
|
| 174 |
+
title = da.attrs.get('long_name', da.name or 'Data')
|
| 175 |
+
ax.set_title(title)
|
| 176 |
+
|
| 177 |
+
plt.tight_layout()
|
| 178 |
+
return fig
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def plot_map(da: xr.DataArray, proj: str = "PlateCarree", **style) -> Figure:
|
| 182 |
+
"""
|
| 183 |
+
Create a map plot with cartopy.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
da: Input DataArray with geographic coordinates
|
| 187 |
+
proj: Map projection name
|
| 188 |
+
**style: Style parameters
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
matplotlib Figure
|
| 192 |
+
"""
|
| 193 |
+
if not HAS_CARTOPY:
|
| 194 |
+
raise ImportError("Cartopy is required for map plotting")
|
| 195 |
+
|
| 196 |
+
setup_matplotlib()
|
| 197 |
+
|
| 198 |
+
# Check if data is geographic
|
| 199 |
+
if not is_geographic(da):
|
| 200 |
+
raise ValueError("DataArray does not appear to have geographic coordinates")
|
| 201 |
+
|
| 202 |
+
# Get coordinate information
|
| 203 |
+
coords = identify_coordinates(da)
|
| 204 |
+
if 'X' not in coords or 'Y' not in coords:
|
| 205 |
+
raise ValueError("Could not identify longitude/latitude coordinates")
|
| 206 |
+
|
| 207 |
+
lon_dim = coords['X']
|
| 208 |
+
lat_dim = coords['Y']
|
| 209 |
+
|
| 210 |
+
# Set up projection
|
| 211 |
+
proj_map = {
|
| 212 |
+
'PlateCarree': ccrs.PlateCarree(),
|
| 213 |
+
'Robinson': ccrs.Robinson(),
|
| 214 |
+
'Mollweide': ccrs.Mollweide(),
|
| 215 |
+
'Orthographic': ccrs.Orthographic(),
|
| 216 |
+
'NorthPolarStereo': ccrs.NorthPolarStereo(),
|
| 217 |
+
'SouthPolarStereo': ccrs.SouthPolarStereo(),
|
| 218 |
+
'Miller': ccrs.Miller(),
|
| 219 |
+
'InterruptedGoodeHomolosine': ccrs.InterruptedGoodeHomolosine()
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
if proj not in proj_map:
|
| 223 |
+
proj = 'PlateCarree' # Default fallback
|
| 224 |
+
|
| 225 |
+
projection = proj_map[proj]
|
| 226 |
+
|
| 227 |
+
# Create figure with cartopy
|
| 228 |
+
fig, ax = plt.subplots(figsize=(12, 8),
|
| 229 |
+
subplot_kw={'projection': projection})
|
| 230 |
+
|
| 231 |
+
# Transpose to get (lat, lon) order
|
| 232 |
+
da_plot = da.transpose(lat_dim, lon_dim)
|
| 233 |
+
|
| 234 |
+
# Get coordinates
|
| 235 |
+
lons = da.coords[lon_dim].values
|
| 236 |
+
lats = da.coords[lat_dim].values
|
| 237 |
+
|
| 238 |
+
# Set up colormap and normalization
|
| 239 |
+
cmap = style.get('cmap', 'viridis')
|
| 240 |
+
if isinstance(cmap, str):
|
| 241 |
+
cmap = plt.get_cmap(cmap)
|
| 242 |
+
|
| 243 |
+
vmin = style.get('vmin', float(da.min().values))
|
| 244 |
+
vmax = style.get('vmax', float(da.max().values))
|
| 245 |
+
|
| 246 |
+
# Create plot
|
| 247 |
+
plot_type = style.get('plot_type', 'pcolormesh')
|
| 248 |
+
|
| 249 |
+
if plot_type == 'contourf':
|
| 250 |
+
levels = style.get('levels', 20)
|
| 251 |
+
if isinstance(levels, int):
|
| 252 |
+
levels = np.linspace(vmin, vmax, levels)
|
| 253 |
+
im = ax.contourf(lons, lats, da_plot.values, levels=levels,
|
| 254 |
+
cmap=cmap, transform=ccrs.PlateCarree())
|
| 255 |
+
else:
|
| 256 |
+
im = ax.pcolormesh(lons, lats, da_plot.values, cmap=cmap,
|
| 257 |
+
transform=ccrs.PlateCarree(),
|
| 258 |
+
vmin=vmin, vmax=vmax, shading='auto')
|
| 259 |
+
|
| 260 |
+
# Add map features
|
| 261 |
+
if style.get('coastlines', True):
|
| 262 |
+
ax.coastlines(resolution='50m', color='black', linewidth=0.5)
|
| 263 |
+
|
| 264 |
+
if style.get('borders', False):
|
| 265 |
+
ax.add_feature(cfeature.BORDERS, linewidth=0.5)
|
| 266 |
+
|
| 267 |
+
if style.get('ocean', False):
|
| 268 |
+
ax.add_feature(cfeature.OCEAN, color='lightblue', alpha=0.5)
|
| 269 |
+
|
| 270 |
+
if style.get('land', False):
|
| 271 |
+
ax.add_feature(cfeature.LAND, color='lightgray', alpha=0.5)
|
| 272 |
+
|
| 273 |
+
# Add gridlines
|
| 274 |
+
if style.get('gridlines', True):
|
| 275 |
+
gl = ax.gridlines(draw_labels=True, alpha=0.5)
|
| 276 |
+
gl.top_labels = False
|
| 277 |
+
gl.right_labels = False
|
| 278 |
+
|
| 279 |
+
# Set extent if specified
|
| 280 |
+
if 'extent' in style:
|
| 281 |
+
ax.set_extent(style['extent'], crs=ccrs.PlateCarree())
|
| 282 |
+
else:
|
| 283 |
+
ax.set_global()
|
| 284 |
+
|
| 285 |
+
# Add colorbar
|
| 286 |
+
if style.get('colorbar', True):
|
| 287 |
+
cbar = plt.colorbar(im, ax=ax, orientation='horizontal',
|
| 288 |
+
pad=0.05, shrink=0.8)
|
| 289 |
+
cbar.set_label(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
|
| 290 |
+
|
| 291 |
+
# Set title
|
| 292 |
+
title = da.attrs.get('long_name', da.name or 'Data')
|
| 293 |
+
ax.set_title(title, pad=20)
|
| 294 |
+
|
| 295 |
+
plt.tight_layout()
|
| 296 |
+
return fig
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def export_fig(fig: Figure, fmt: Literal["png", "svg", "pdf"] = "png",
|
| 300 |
+
dpi: int = 150, out_path: Optional[str] = None) -> str:
|
| 301 |
+
"""
|
| 302 |
+
Export a figure to file or return as bytes.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
fig: matplotlib Figure
|
| 306 |
+
fmt: Output format
|
| 307 |
+
dpi: Resolution for raster formats
|
| 308 |
+
out_path: Output file path (if None, returns bytes)
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
File path or bytes
|
| 312 |
+
"""
|
| 313 |
+
if out_path is None:
|
| 314 |
+
# Return as bytes
|
| 315 |
+
buf = io.BytesIO()
|
| 316 |
+
fig.savefig(buf, format=fmt, dpi=dpi, bbox_inches='tight')
|
| 317 |
+
buf.seek(0)
|
| 318 |
+
return buf.getvalue()
|
| 319 |
+
else:
|
| 320 |
+
# Save to file
|
| 321 |
+
fig.savefig(out_path, format=fmt, dpi=dpi, bbox_inches='tight')
|
| 322 |
+
return out_path
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def create_subplot_figure(n_plots: int, ncols: int = 2) -> Tuple[Figure, np.ndarray]:
|
| 326 |
+
"""Create a figure with multiple subplots."""
|
| 327 |
+
nrows = (n_plots + ncols - 1) // ncols
|
| 328 |
+
fig, axes = plt.subplots(nrows, ncols, figsize=(6*ncols, 4*nrows))
|
| 329 |
+
|
| 330 |
+
if n_plots == 1:
|
| 331 |
+
axes = np.array([axes])
|
| 332 |
+
elif nrows == 1:
|
| 333 |
+
axes = axes.reshape(1, -1)
|
| 334 |
+
|
| 335 |
+
# Hide unused subplots
|
| 336 |
+
for i in range(n_plots, nrows * ncols):
|
| 337 |
+
axes.flat[i].set_visible(False)
|
| 338 |
+
|
| 339 |
+
return fig, axes
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def add_statistics_text(ax: Axes, da: xr.DataArray, x: float = 0.02, y: float = 0.98):
|
| 343 |
+
"""Add statistics text to a plot."""
|
| 344 |
+
stats = [
|
| 345 |
+
f"Min: {float(da.min().values):.3g}",
|
| 346 |
+
f"Max: {float(da.max().values):.3g}",
|
| 347 |
+
f"Mean: {float(da.mean().values):.3g}",
|
| 348 |
+
f"Std: {float(da.std().values):.3g}"
|
| 349 |
+
]
|
| 350 |
+
|
| 351 |
+
text = '\n'.join(stats)
|
| 352 |
+
ax.text(x, y, text, transform=ax.transAxes,
|
| 353 |
+
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
|
| 354 |
+
verticalalignment='top', fontsize=8)
|
tensorview/state.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""View state serialization and deserialization for saving/loading plot configurations."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import Dict, Any, Optional, List
|
| 5 |
+
from dataclasses import dataclass, asdict
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import orjson
|
| 8 |
+
from pydantic import BaseModel, Field, validator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PlotConfig(BaseModel):
|
| 12 |
+
"""Plot configuration schema."""
|
| 13 |
+
plot_type: str = Field(..., description="Type of plot (1d, 2d, map)")
|
| 14 |
+
x_dim: Optional[str] = Field(None, description="X-axis dimension")
|
| 15 |
+
y_dim: Optional[str] = Field(None, description="Y-axis dimension")
|
| 16 |
+
projection: str = Field("PlateCarree", description="Map projection")
|
| 17 |
+
colormap: str = Field("viridis", description="Colormap name")
|
| 18 |
+
vmin: Optional[float] = Field(None, description="Color scale minimum")
|
| 19 |
+
vmax: Optional[float] = Field(None, description="Color scale maximum")
|
| 20 |
+
levels: Optional[List[float]] = Field(None, description="Contour levels")
|
| 21 |
+
style: Dict[str, Any] = Field(default_factory=dict, description="Additional style parameters")
|
| 22 |
+
|
| 23 |
+
@validator('plot_type')
|
| 24 |
+
def validate_plot_type(cls, v):
|
| 25 |
+
allowed = ['1d', '2d', 'map', 'contour']
|
| 26 |
+
if v not in allowed:
|
| 27 |
+
raise ValueError(f"plot_type must be one of {allowed}")
|
| 28 |
+
return v
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DataConfig(BaseModel):
|
| 32 |
+
"""Data configuration schema."""
|
| 33 |
+
uri: str = Field(..., description="Data source URI")
|
| 34 |
+
engine: Optional[str] = Field(None, description="Data engine")
|
| 35 |
+
variable_a: str = Field(..., description="Primary variable")
|
| 36 |
+
variable_b: Optional[str] = Field(None, description="Secondary variable for operations")
|
| 37 |
+
operation: str = Field("none", description="Operation between variables")
|
| 38 |
+
selections: Dict[str, Any] = Field(default_factory=dict, description="Dimension selections")
|
| 39 |
+
|
| 40 |
+
@validator('operation')
|
| 41 |
+
def validate_operation(cls, v):
|
| 42 |
+
allowed = ['none', 'sum', 'avg', 'diff']
|
| 43 |
+
if v not in allowed:
|
| 44 |
+
raise ValueError(f"operation must be one of {allowed}")
|
| 45 |
+
return v
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ViewState(BaseModel):
|
| 49 |
+
"""Complete view state schema."""
|
| 50 |
+
version: str = Field("1.0", description="State schema version")
|
| 51 |
+
created: str = Field(default_factory=lambda: datetime.utcnow().isoformat(), description="Creation timestamp")
|
| 52 |
+
title: str = Field("", description="User-defined title")
|
| 53 |
+
description: str = Field("", description="User-defined description")
|
| 54 |
+
data_config: DataConfig = Field(..., description="Data configuration")
|
| 55 |
+
plot_config: PlotConfig = Field(..., description="Plot configuration")
|
| 56 |
+
animation: Optional[Dict[str, Any]] = Field(None, description="Animation settings")
|
| 57 |
+
exports: List[str] = Field(default_factory=list, description="Export history")
|
| 58 |
+
|
| 59 |
+
class Config:
|
| 60 |
+
extra = "allow" # Allow additional fields for extensibility
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def create_view_state(data_config: Dict[str, Any], plot_config: Dict[str, Any],
|
| 64 |
+
title: str = "", description: str = "",
|
| 65 |
+
animation: Optional[Dict[str, Any]] = None) -> ViewState:
|
| 66 |
+
"""
|
| 67 |
+
Create a new view state object.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
data_config: Data configuration dictionary
|
| 71 |
+
plot_config: Plot configuration dictionary
|
| 72 |
+
title: Optional title
|
| 73 |
+
description: Optional description
|
| 74 |
+
animation: Optional animation settings
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
ViewState object
|
| 78 |
+
"""
|
| 79 |
+
data_cfg = DataConfig(**data_config)
|
| 80 |
+
plot_cfg = PlotConfig(**plot_config)
|
| 81 |
+
|
| 82 |
+
return ViewState(
|
| 83 |
+
title=title,
|
| 84 |
+
description=description,
|
| 85 |
+
data_config=data_cfg,
|
| 86 |
+
plot_config=plot_cfg,
|
| 87 |
+
animation=animation
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def dump_state(state: ViewState) -> str:
|
| 92 |
+
"""
|
| 93 |
+
Serialize a view state to JSON string.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
state: ViewState object
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
JSON string
|
| 100 |
+
"""
|
| 101 |
+
# Use orjson for better performance and datetime handling
|
| 102 |
+
return orjson.dumps(state.dict(), option=orjson.OPT_INDENT_2).decode('utf-8')
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def load_state(state_json: str) -> ViewState:
|
| 106 |
+
"""
|
| 107 |
+
Deserialize a view state from JSON string.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
state_json: JSON string
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
ViewState object
|
| 114 |
+
"""
|
| 115 |
+
try:
|
| 116 |
+
data = orjson.loads(state_json)
|
| 117 |
+
return ViewState(**data)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
raise ValueError(f"Failed to parse view state: {str(e)}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def save_state_file(state: ViewState, filepath: str) -> str:
|
| 123 |
+
"""
|
| 124 |
+
Save view state to a file.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
state: ViewState object
|
| 128 |
+
filepath: Output file path
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
File path
|
| 132 |
+
"""
|
| 133 |
+
state_json = dump_state(state)
|
| 134 |
+
|
| 135 |
+
with open(filepath, 'w', encoding='utf-8') as f:
|
| 136 |
+
f.write(state_json)
|
| 137 |
+
|
| 138 |
+
return filepath
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def load_state_file(filepath: str) -> ViewState:
|
| 142 |
+
"""
|
| 143 |
+
Load view state from a file.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
filepath: Input file path
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
ViewState object
|
| 150 |
+
"""
|
| 151 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 152 |
+
state_json = f.read()
|
| 153 |
+
|
| 154 |
+
return load_state(state_json)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def merge_states(base_state: ViewState, updates: Dict[str, Any]) -> ViewState:
|
| 158 |
+
"""
|
| 159 |
+
Merge updates into a base state.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
base_state: Base ViewState object
|
| 163 |
+
updates: Dictionary of updates
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
New ViewState object with updates applied
|
| 167 |
+
"""
|
| 168 |
+
# Convert to dict, apply updates, and create new state
|
| 169 |
+
state_dict = base_state.dict()
|
| 170 |
+
|
| 171 |
+
# Deep merge for nested dictionaries
|
| 172 |
+
def deep_merge(base_dict, update_dict):
|
| 173 |
+
for key, value in update_dict.items():
|
| 174 |
+
if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
|
| 175 |
+
deep_merge(base_dict[key], value)
|
| 176 |
+
else:
|
| 177 |
+
base_dict[key] = value
|
| 178 |
+
|
| 179 |
+
deep_merge(state_dict, updates)
|
| 180 |
+
return ViewState(**state_dict)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def validate_state_compatibility(state: ViewState) -> List[str]:
|
| 184 |
+
"""
|
| 185 |
+
Check state compatibility and return any warnings.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
state: ViewState object
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
List of warning messages
|
| 192 |
+
"""
|
| 193 |
+
warnings = []
|
| 194 |
+
|
| 195 |
+
# Check version compatibility
|
| 196 |
+
current_version = "1.0"
|
| 197 |
+
if state.version != current_version:
|
| 198 |
+
warnings.append(f"State version {state.version} may not be fully compatible with current version {current_version}")
|
| 199 |
+
|
| 200 |
+
# Check plot type and dimension compatibility
|
| 201 |
+
plot_config = state.plot_config
|
| 202 |
+
data_config = state.data_config
|
| 203 |
+
|
| 204 |
+
if plot_config.plot_type == "map":
|
| 205 |
+
if not plot_config.x_dim or not plot_config.y_dim:
|
| 206 |
+
warnings.append("Map plots require both x_dim and y_dim to be specified")
|
| 207 |
+
|
| 208 |
+
elif plot_config.plot_type == "2d":
|
| 209 |
+
if not plot_config.x_dim or not plot_config.y_dim:
|
| 210 |
+
warnings.append("2D plots require both x_dim and y_dim to be specified")
|
| 211 |
+
|
| 212 |
+
elif plot_config.plot_type == "1d":
|
| 213 |
+
if not plot_config.x_dim:
|
| 214 |
+
warnings.append("1D plots require x_dim to be specified")
|
| 215 |
+
|
| 216 |
+
# Check operation compatibility
|
| 217 |
+
if data_config.operation != "none" and not data_config.variable_b:
|
| 218 |
+
warnings.append(f"Operation '{data_config.operation}' requires variable_b to be specified")
|
| 219 |
+
|
| 220 |
+
return warnings
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def create_default_state(uri: str, variable: str) -> ViewState:
|
| 224 |
+
"""
|
| 225 |
+
Create a default view state for a given data source and variable.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
uri: Data source URI
|
| 229 |
+
variable: Variable name
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Default ViewState object
|
| 233 |
+
"""
|
| 234 |
+
data_config = {
|
| 235 |
+
'uri': uri,
|
| 236 |
+
'variable_a': variable,
|
| 237 |
+
'operation': 'none',
|
| 238 |
+
'selections': {}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
plot_config = {
|
| 242 |
+
'plot_type': '2d',
|
| 243 |
+
'colormap': 'viridis',
|
| 244 |
+
'style': {
|
| 245 |
+
'colorbar': True,
|
| 246 |
+
'grid': True
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
return create_view_state(data_config, plot_config, title=f"View of {variable}")
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def export_state_summary(state: ViewState) -> Dict[str, Any]:
|
| 254 |
+
"""
|
| 255 |
+
Create a human-readable summary of a view state.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
state: ViewState object
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
Summary dictionary
|
| 262 |
+
"""
|
| 263 |
+
summary = {
|
| 264 |
+
'title': state.title or "Untitled View",
|
| 265 |
+
'created': state.created,
|
| 266 |
+
'data_source': state.data_config.uri,
|
| 267 |
+
'primary_variable': state.data_config.variable_a,
|
| 268 |
+
'plot_type': state.plot_config.plot_type,
|
| 269 |
+
'has_secondary_variable': state.data_config.variable_b is not None,
|
| 270 |
+
'operation': state.data_config.operation,
|
| 271 |
+
'colormap': state.plot_config.colormap,
|
| 272 |
+
'has_animation': state.animation is not None,
|
| 273 |
+
'export_count': len(state.exports)
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
# Add dimension info
|
| 277 |
+
selections = state.data_config.selections
|
| 278 |
+
if selections:
|
| 279 |
+
summary['fixed_dimensions'] = list(selections.keys())
|
| 280 |
+
summary['selection_count'] = len(selections)
|
| 281 |
+
|
| 282 |
+
# Add plot-specific info
|
| 283 |
+
if state.plot_config.plot_type == "map":
|
| 284 |
+
summary['projection'] = state.plot_config.projection
|
| 285 |
+
|
| 286 |
+
return summary
|
tensorview/utils.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for CF conventions and coordinate system helpers."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from typing import Dict, List, Optional, Tuple, Any
|
| 5 |
+
import numpy as np
|
| 6 |
+
import xarray as xr
|
| 7 |
+
from pyproj import CRS
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def guess_cf_axis(da: xr.DataArray, coord_name: str) -> Optional[str]:
|
| 11 |
+
"""
|
| 12 |
+
Guess the CF axis type (X, Y, Z, T) for a coordinate.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
da: DataArray containing the coordinate
|
| 16 |
+
coord_name: Name of the coordinate
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
CF axis type ('X', 'Y', 'Z', 'T') or None
|
| 20 |
+
"""
|
| 21 |
+
if coord_name not in da.coords:
|
| 22 |
+
return None
|
| 23 |
+
|
| 24 |
+
coord = da.coords[coord_name]
|
| 25 |
+
attrs = coord.attrs
|
| 26 |
+
name_lower = coord_name.lower()
|
| 27 |
+
|
| 28 |
+
# Check explicit axis attribute
|
| 29 |
+
if 'axis' in attrs:
|
| 30 |
+
return attrs['axis'].upper()
|
| 31 |
+
|
| 32 |
+
# Check standard_name
|
| 33 |
+
standard_name = attrs.get('standard_name', '').lower()
|
| 34 |
+
if standard_name in ['longitude', 'projection_x_coordinate']:
|
| 35 |
+
return 'X'
|
| 36 |
+
elif standard_name in ['latitude', 'projection_y_coordinate']:
|
| 37 |
+
return 'Y'
|
| 38 |
+
elif standard_name in ['time']:
|
| 39 |
+
return 'T'
|
| 40 |
+
elif 'altitude' in standard_name or 'height' in standard_name or standard_name == 'air_pressure':
|
| 41 |
+
return 'Z'
|
| 42 |
+
|
| 43 |
+
# Check coordinate name patterns
|
| 44 |
+
if any(pattern in name_lower for pattern in ['lon', 'x']):
|
| 45 |
+
return 'X'
|
| 46 |
+
elif any(pattern in name_lower for pattern in ['lat', 'y']):
|
| 47 |
+
return 'Y'
|
| 48 |
+
elif any(pattern in name_lower for pattern in ['time', 't']):
|
| 49 |
+
return 'T'
|
| 50 |
+
elif any(pattern in name_lower for pattern in ['lev', 'level', 'pressure', 'z', 'height', 'alt']):
|
| 51 |
+
return 'Z'
|
| 52 |
+
|
| 53 |
+
# Check units
|
| 54 |
+
units = attrs.get('units', '').lower()
|
| 55 |
+
if any(unit in units for unit in ['degree_east', 'degrees_east', 'degree_e']):
|
| 56 |
+
return 'X'
|
| 57 |
+
elif any(unit in units for unit in ['degree_north', 'degrees_north', 'degree_n']):
|
| 58 |
+
return 'Y'
|
| 59 |
+
elif any(unit in units for unit in ['days since', 'hours since', 'seconds since']):
|
| 60 |
+
return 'T'
|
| 61 |
+
elif any(unit in units for unit in ['pa', 'hpa', 'mbar', 'mb', 'm', 'km']):
|
| 62 |
+
return 'Z'
|
| 63 |
+
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def identify_coordinates(da: xr.DataArray) -> Dict[str, str]:
|
| 68 |
+
"""
|
| 69 |
+
Identify coordinate types in a DataArray.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
da: Input DataArray
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Dictionary mapping axis type to coordinate name
|
| 76 |
+
"""
|
| 77 |
+
coords = {}
|
| 78 |
+
|
| 79 |
+
for coord_name in da.dims:
|
| 80 |
+
axis = guess_cf_axis(da, coord_name)
|
| 81 |
+
if axis:
|
| 82 |
+
coords[axis] = coord_name
|
| 83 |
+
|
| 84 |
+
return coords
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_crs(da: xr.DataArray) -> Optional[CRS]:
|
| 88 |
+
"""
|
| 89 |
+
Extract CRS information from a DataArray.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
da: Input DataArray
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
pyproj CRS object or None
|
| 96 |
+
"""
|
| 97 |
+
# Check for grid_mapping attribute
|
| 98 |
+
grid_mapping = da.attrs.get('grid_mapping')
|
| 99 |
+
if grid_mapping and grid_mapping in da.coords:
|
| 100 |
+
gm_var = da.coords[grid_mapping]
|
| 101 |
+
|
| 102 |
+
# Try to construct CRS from grid mapping attributes
|
| 103 |
+
try:
|
| 104 |
+
crs_attrs = dict(gm_var.attrs)
|
| 105 |
+
return CRS.from_cf(crs_attrs)
|
| 106 |
+
except:
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
# Check for crs coordinate
|
| 110 |
+
if 'crs' in da.coords:
|
| 111 |
+
try:
|
| 112 |
+
return CRS.from_cf(dict(da.coords['crs'].attrs))
|
| 113 |
+
except:
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
# Check for spatial_ref coordinate (common in rioxarray)
|
| 117 |
+
if 'spatial_ref' in da.coords:
|
| 118 |
+
try:
|
| 119 |
+
spatial_ref = da.coords['spatial_ref']
|
| 120 |
+
if hasattr(spatial_ref, 'spatial_ref'):
|
| 121 |
+
return CRS.from_wkt(spatial_ref.spatial_ref)
|
| 122 |
+
elif 'crs_wkt' in spatial_ref.attrs:
|
| 123 |
+
return CRS.from_wkt(spatial_ref.attrs['crs_wkt'])
|
| 124 |
+
except:
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
# Default to geographic CRS if we have lat/lon
|
| 128 |
+
coords = identify_coordinates(da)
|
| 129 |
+
if 'X' in coords and 'Y' in coords:
|
| 130 |
+
x_coord = da.coords[coords['X']]
|
| 131 |
+
y_coord = da.coords[coords['Y']]
|
| 132 |
+
|
| 133 |
+
# Check if coordinates look like geographic
|
| 134 |
+
x_range = float(x_coord.max()) - float(x_coord.min())
|
| 135 |
+
y_range = float(y_coord.max()) - float(y_coord.min())
|
| 136 |
+
|
| 137 |
+
if -180 <= x_coord.min() <= x_coord.max() <= 360 and -90 <= y_coord.min() <= y_coord.max() <= 90:
|
| 138 |
+
return CRS.from_epsg(4326) # WGS84
|
| 139 |
+
|
| 140 |
+
return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def is_geographic(da: xr.DataArray) -> bool:
|
| 144 |
+
"""Check if DataArray uses geographic coordinates."""
|
| 145 |
+
crs = get_crs(da)
|
| 146 |
+
if crs:
|
| 147 |
+
return crs.is_geographic
|
| 148 |
+
|
| 149 |
+
# Fallback: check coordinate ranges
|
| 150 |
+
coords = identify_coordinates(da)
|
| 151 |
+
if 'X' in coords and 'Y' in coords:
|
| 152 |
+
x_coord = da.coords[coords['X']]
|
| 153 |
+
y_coord = da.coords[coords['Y']]
|
| 154 |
+
|
| 155 |
+
x_range = float(x_coord.max()) - float(x_coord.min())
|
| 156 |
+
y_range = float(y_coord.max()) - float(y_coord.min())
|
| 157 |
+
|
| 158 |
+
return (-180 <= x_coord.min() <= x_coord.max() <= 360 and
|
| 159 |
+
-90 <= y_coord.min() <= y_coord.max() <= 90)
|
| 160 |
+
|
| 161 |
+
return False
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def ensure_longitude_range(da: xr.DataArray, range_type: str = '180') -> xr.DataArray:
|
| 165 |
+
"""
|
| 166 |
+
Ensure longitude coordinates are in the specified range.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
da: Input DataArray
|
| 170 |
+
range_type: '180' for [-180, 180] or '360' for [0, 360]
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
DataArray with adjusted longitude coordinates
|
| 174 |
+
"""
|
| 175 |
+
coords = identify_coordinates(da)
|
| 176 |
+
if 'X' not in coords:
|
| 177 |
+
return da
|
| 178 |
+
|
| 179 |
+
x_coord = coords['X']
|
| 180 |
+
da_copy = da.copy()
|
| 181 |
+
|
| 182 |
+
if range_type == '180':
|
| 183 |
+
# Convert to [-180, 180]
|
| 184 |
+
da_copy.coords[x_coord] = ((da_copy.coords[x_coord] + 180) % 360) - 180
|
| 185 |
+
elif range_type == '360':
|
| 186 |
+
# Convert to [0, 360]
|
| 187 |
+
da_copy.coords[x_coord] = da_copy.coords[x_coord] % 360
|
| 188 |
+
|
| 189 |
+
# Sort by longitude if needed
|
| 190 |
+
if x_coord in da_copy.dims:
|
| 191 |
+
da_copy = da_copy.sortby(x_coord)
|
| 192 |
+
|
| 193 |
+
return da_copy
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_time_bounds(da: xr.DataArray) -> Optional[Tuple[Any, Any]]:
|
| 197 |
+
"""Get time bounds from a DataArray."""
|
| 198 |
+
coords = identify_coordinates(da)
|
| 199 |
+
if 'T' not in coords:
|
| 200 |
+
return None
|
| 201 |
+
|
| 202 |
+
time_coord = da.coords[coords['T']]
|
| 203 |
+
return (time_coord.min().values, time_coord.max().values)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def format_value(value: Any, coord_name: str = '') -> str:
|
| 207 |
+
"""Format a coordinate value for display."""
|
| 208 |
+
if isinstance(value, np.datetime64):
|
| 209 |
+
return str(value)[:19] # Remove nanoseconds
|
| 210 |
+
elif isinstance(value, (int, np.integer)):
|
| 211 |
+
return str(value)
|
| 212 |
+
elif isinstance(value, (float, np.floating)):
|
| 213 |
+
if 'time' in coord_name.lower():
|
| 214 |
+
return f"{value:.1f}"
|
| 215 |
+
else:
|
| 216 |
+
return f"{value:.3f}"
|
| 217 |
+
else:
|
| 218 |
+
return str(value)
|
tests/test_io.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for I/O operations."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import numpy as np
|
| 5 |
+
import xarray as xr
|
| 6 |
+
import tempfile
|
| 7 |
+
import os
|
| 8 |
+
from tensorview.io import open_any, list_variables, get_dataarray, detect_engine
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_sample_netcdf():
|
| 12 |
+
"""Create a sample NetCDF file for testing."""
|
| 13 |
+
# Create sample data
|
| 14 |
+
lons = np.arange(-180, 180, 2.5)
|
| 15 |
+
lats = np.arange(-90, 90, 2.5)
|
| 16 |
+
times = np.arange(0, 10)
|
| 17 |
+
|
| 18 |
+
lon_grid, lat_grid = np.meshgrid(lons, lats)
|
| 19 |
+
|
| 20 |
+
# Create sample temperature data
|
| 21 |
+
temp_data = np.random.randn(len(times), len(lats), len(lons)) + 20
|
| 22 |
+
|
| 23 |
+
# Create xarray Dataset
|
| 24 |
+
ds = xr.Dataset({
|
| 25 |
+
'temperature': (['time', 'lat', 'lon'], temp_data, {
|
| 26 |
+
'units': 'degrees_C',
|
| 27 |
+
'long_name': 'Temperature',
|
| 28 |
+
'standard_name': 'air_temperature'
|
| 29 |
+
})
|
| 30 |
+
}, coords={
|
| 31 |
+
'time': ('time', times, {'units': 'days since 2000-01-01'}),
|
| 32 |
+
'lat': ('lat', lats, {'units': 'degrees_north', 'long_name': 'Latitude'}),
|
| 33 |
+
'lon': ('lon', lons, {'units': 'degrees_east', 'long_name': 'Longitude'})
|
| 34 |
+
})
|
| 35 |
+
|
| 36 |
+
# Save to temporary file
|
| 37 |
+
temp_file = tempfile.NamedTemporaryFile(suffix='.nc', delete=False)
|
| 38 |
+
ds.to_netcdf(temp_file.name)
|
| 39 |
+
temp_file.close()
|
| 40 |
+
|
| 41 |
+
return temp_file.name
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_detect_engine():
|
| 45 |
+
"""Test engine detection."""
|
| 46 |
+
assert detect_engine('test.nc') == 'h5netcdf'
|
| 47 |
+
assert detect_engine('test.grib') == 'cfgrib'
|
| 48 |
+
assert detect_engine('test.zarr') == 'zarr'
|
| 49 |
+
assert detect_engine('test.h5') == 'h5netcdf'
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_open_netcdf():
|
| 53 |
+
"""Test opening NetCDF files."""
|
| 54 |
+
nc_file = create_sample_netcdf()
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
# Test opening
|
| 58 |
+
handle = open_any(nc_file)
|
| 59 |
+
assert handle is not None
|
| 60 |
+
assert handle.engine in ['h5netcdf', 'netcdf4']
|
| 61 |
+
|
| 62 |
+
# Test variable listing
|
| 63 |
+
variables = list_variables(handle)
|
| 64 |
+
assert len(variables) == 1
|
| 65 |
+
assert variables[0].name == 'temperature'
|
| 66 |
+
assert variables[0].units == 'degrees_C'
|
| 67 |
+
|
| 68 |
+
# Test getting data array
|
| 69 |
+
da = get_dataarray(handle, 'temperature')
|
| 70 |
+
assert da.name == 'temperature'
|
| 71 |
+
assert len(da.dims) == 3
|
| 72 |
+
assert 'time' in da.dims
|
| 73 |
+
assert 'lat' in da.dims
|
| 74 |
+
assert 'lon' in da.dims
|
| 75 |
+
|
| 76 |
+
# Clean up
|
| 77 |
+
handle.close()
|
| 78 |
+
|
| 79 |
+
finally:
|
| 80 |
+
os.unlink(nc_file)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def test_invalid_file():
|
| 84 |
+
"""Test error handling for invalid files."""
|
| 85 |
+
with pytest.raises(RuntimeError):
|
| 86 |
+
open_any('nonexistent_file.nc')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_invalid_variable():
|
| 90 |
+
"""Test error handling for invalid variables."""
|
| 91 |
+
nc_file = create_sample_netcdf()
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
handle = open_any(nc_file)
|
| 95 |
+
|
| 96 |
+
with pytest.raises(ValueError):
|
| 97 |
+
get_dataarray(handle, 'nonexistent_variable')
|
| 98 |
+
|
| 99 |
+
handle.close()
|
| 100 |
+
|
| 101 |
+
finally:
|
| 102 |
+
os.unlink(nc_file)
|
tests/test_plot.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for plotting functionality."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import numpy as np
|
| 5 |
+
import xarray as xr
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from tensorview.plot import plot_1d, plot_2d, setup_matplotlib
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_sample_data():
|
| 11 |
+
"""Create sample data for testing."""
|
| 12 |
+
# 1D data
|
| 13 |
+
x = np.linspace(0, 10, 100)
|
| 14 |
+
y = np.sin(x)
|
| 15 |
+
da_1d = xr.DataArray(y, coords={'x': x}, dims=['x'],
|
| 16 |
+
attrs={'units': 'm/s', 'long_name': 'Sine Wave'})
|
| 17 |
+
|
| 18 |
+
# 2D data
|
| 19 |
+
lons = np.linspace(-10, 10, 20)
|
| 20 |
+
lats = np.linspace(-10, 10, 15)
|
| 21 |
+
lon_grid, lat_grid = np.meshgrid(lons, lats)
|
| 22 |
+
temp_data = np.sin(lon_grid/5) * np.cos(lat_grid/5) + np.random.randn(*lat_grid.shape) * 0.1
|
| 23 |
+
|
| 24 |
+
da_2d = xr.DataArray(temp_data,
|
| 25 |
+
coords={'lat': lats, 'lon': lons},
|
| 26 |
+
dims=['lat', 'lon'],
|
| 27 |
+
attrs={'units': 'degrees_C', 'long_name': 'Temperature'})
|
| 28 |
+
|
| 29 |
+
return da_1d, da_2d
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_setup_matplotlib():
|
| 33 |
+
"""Test matplotlib setup."""
|
| 34 |
+
setup_matplotlib()
|
| 35 |
+
assert plt.get_backend() == 'Agg'
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_plot_1d():
|
| 39 |
+
"""Test 1D plotting."""
|
| 40 |
+
da_1d, _ = create_sample_data()
|
| 41 |
+
|
| 42 |
+
fig = plot_1d(da_1d)
|
| 43 |
+
assert fig is not None
|
| 44 |
+
assert len(fig.axes) == 1
|
| 45 |
+
|
| 46 |
+
ax = fig.axes[0]
|
| 47 |
+
assert len(ax.lines) == 1
|
| 48 |
+
assert ax.get_xlabel() == 'x ()'
|
| 49 |
+
assert 'Sine Wave' in ax.get_title()
|
| 50 |
+
|
| 51 |
+
plt.close(fig)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_plot_2d():
|
| 55 |
+
"""Test 2D plotting."""
|
| 56 |
+
_, da_2d = create_sample_data()
|
| 57 |
+
|
| 58 |
+
# Test image plot
|
| 59 |
+
fig = plot_2d(da_2d, kind="image")
|
| 60 |
+
assert fig is not None
|
| 61 |
+
assert len(fig.axes) >= 1 # Plot axis + possibly colorbar axis
|
| 62 |
+
|
| 63 |
+
ax = fig.axes[0]
|
| 64 |
+
assert ax.get_xlabel() == 'lon ()'
|
| 65 |
+
assert ax.get_ylabel() == 'lat ()'
|
| 66 |
+
assert 'Temperature' in ax.get_title()
|
| 67 |
+
|
| 68 |
+
plt.close(fig)
|
| 69 |
+
|
| 70 |
+
# Test contour plot
|
| 71 |
+
fig = plot_2d(da_2d, kind="contour")
|
| 72 |
+
assert fig is not None
|
| 73 |
+
plt.close(fig)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_plot_styling():
|
| 77 |
+
"""Test plot styling options."""
|
| 78 |
+
da_1d, da_2d = create_sample_data()
|
| 79 |
+
|
| 80 |
+
# Test 1D styling
|
| 81 |
+
fig = plot_1d(da_1d, color='red', linewidth=2, grid=False)
|
| 82 |
+
ax = fig.axes[0]
|
| 83 |
+
assert ax.lines[0].get_color() == 'red'
|
| 84 |
+
assert ax.lines[0].get_linewidth() == 2
|
| 85 |
+
plt.close(fig)
|
| 86 |
+
|
| 87 |
+
# Test 2D styling
|
| 88 |
+
fig = plot_2d(da_2d, cmap='plasma', vmin=-1, vmax=1)
|
| 89 |
+
assert fig is not None
|
| 90 |
+
plt.close(fig)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_auto_dimension_detection():
|
| 94 |
+
"""Test automatic dimension detection."""
|
| 95 |
+
_, da_2d = create_sample_data()
|
| 96 |
+
|
| 97 |
+
# Should work without specifying dimensions
|
| 98 |
+
fig = plot_2d(da_2d)
|
| 99 |
+
assert fig is not None
|
| 100 |
+
plt.close(fig)
|