Nipun Claude committed on
Commit 433dab5 · 1 Parent(s): 3aaaab0

🌍 TensorView v1.0 - Complete NetCDF/HDF/GRIB viewer


Features:
- Multi-dimensional data exploration with smart slicing
- Geographic map plotting with Cartopy projections
- Automatic percentile-based color scaling
- Support for NetCDF, HDF5, GRIB, Zarr formats
- Interactive Gradio interface with dynamic controls
- File upload + direct path input modes
- Comprehensive scientific data visualization

πŸ€– Generated with Claude Code

Co-Authored-By: Claude <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,64 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # Temporary files
+ *.tmp
+ *.temp
+ *~
+
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # Test outputs
+ test_*.png
+ test_*.nc
+
+ # Debug files
+ debug_*.py
+ app_simple.py
+ app_fixed.py
+
+ # Large data files
+ *.nc
+ *.hdf
+ *.h5
+ *.grib
+ *.grb
+ tests/data/*.nc
README.md CHANGED
@@ -1,12 +1,96 @@
  ---
- title: Ncview
- emoji: πŸŒ–
- colorFrom: pink
- colorTo: red
  sdk: gradio
- sdk_version: 5.42.0
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: TensorView - NetCDF/HDF/GRIB Viewer
+ emoji: 🌍
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
+ license: mit
  ---

+ # 🌍 TensorView - Interactive Geospatial Data Viewer
+
+ A powerful browser-based viewer for **NetCDF**, **HDF**, **GRIB**, and **Zarr** datasets with advanced visualization capabilities.
+
+ ## πŸš€ Features
+
+ - **πŸ“Š Multi-dimensional data exploration** - Handle complex scientific datasets with automatic slicing
+ - **πŸ—ΊοΈ Geographic mapping** - Built-in map projections with coastlines and gridlines
+ - **🎨 Smart color scaling** - Automatic percentile-based color limits for optimal visualization
+ - **πŸ”„ Multiple data formats** - NetCDF, HDF5, GRIB, Zarr support
+ - **πŸŽ›οΈ Interactive controls** - Dynamic sliders for dimension exploration
+ - **πŸ“€ Dual input modes** - File upload or direct file path input
+ - **🌐 Remote data support** - Load from URLs, OPeNDAP, THREDDS servers
+
+ ## 🎯 Quick Start
+
+ 1. **Upload a file** or enter a file path
+ 2. **Select a variable** from the dropdown
+ 3. **Choose plot type**: 2D Image or Map (for geographic data)
+ 4. **Adjust dimension sliders** to explore different time steps, pressure levels, etc.
+ 5. **Create plot** and explore your data!
+
+ ## πŸ“Š Supported Data Sources
+
+ - **NetCDF files** (.nc, .netcdf) - Climate and weather data
+ - **HDF5 files** (.h5, .hdf) - Scientific datasets
+ - **GRIB files** (.grib, .grb) - Meteorological data
+ - **Zarr stores** - Cloud-optimized arrays
+ - **Remote URLs** - HTTP/HTTPS links to data files
+ - **OPeNDAP/THREDDS** - Direct server access
+
+ ## 🌟 Example Use Cases
+
+ - **Climate Data**: ERA5 reanalysis, CMIP model outputs
+ - **Weather Data**: GFS/ECMWF forecasts, radar data
+ - **Air Quality**: CAMS atmospheric composition data
+ - **Oceanography**: Sea surface temperature, currents
+ - **Satellite Data**: Remote sensing products
+
+ ## πŸ”§ Technical Details
+
+ Built with:
+ - **xarray + Dask** - Efficient handling of large datasets
+ - **matplotlib + Cartopy** - High-quality plotting and maps
+ - **Gradio** - Interactive web interface
+ - **Multi-engine support** - h5netcdf, netcdf4, cfgrib, zarr
+
+ ### Smart Features
+ - **Automatic color scaling** using 2nd-98th percentiles
+ - **Dimension detection** with dynamic slider generation
+ - **Geographic coordinate recognition** for map plotting
+ - **Memory-efficient** lazy loading with Dask
+
+ ## πŸ’‘ Tips
+
+ - For **5D data** (like CAMS forecasts): Use sliders to select time, pressure level, etc.
+ - For **geographic data**: Choose "Map" plot type for proper projections
+ - **Large files**: The app handles big datasets efficiently with lazy loading
+ - **Color issues**: The app automatically optimizes color scaling to avoid uniform plots
+
+ ## πŸ—οΈ Architecture
+
+ ```
+ tensorview/
+ β”œβ”€β”€ io.py      # Data loading (NetCDF, HDF, GRIB, Zarr)
+ β”œβ”€β”€ plot.py    # Visualization (1D, 2D, maps)
+ β”œβ”€β”€ grid.py    # Data operations and alignment
+ β”œβ”€β”€ colors.py  # Colormap handling
+ β”œβ”€β”€ utils.py   # Coordinate inference
+ └── ...
+ ```
+
+ ## πŸ“ Example Datasets
+
+ The app works great with:
+ - NASA Goddard Earth Sciences Data
+ - ECMWF ERA5 reanalysis
+ - NOAA climate datasets
+ - Copernicus atmosphere monitoring (CAMS)
+ - CMIP climate model outputs
+
+ ---
+
+ **πŸ”— Links**: [GitHub Repository](https://github.com/user/tensorview) | [Documentation](https://docs.tensorview.io)
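As an aside, the 2nd-98th percentile color scaling the README describes can be sketched in a few lines. `robust_limits` is an illustrative name, not a function in the app; the same logic appears inline in the new `app.py` in this commit.

```python
import numpy as np

def robust_limits(values, lo=2, hi=98):
    """Color limits from percentiles, ignoring NaN/inf values."""
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0.0, 1.0
    vmin = float(np.percentile(finite, lo))
    vmax = float(np.percentile(finite, hi))
    if abs(vmax - vmin) < 1e-10:  # degenerate case: near-constant field
        pad = abs(vmin) * 0.1 if vmin != 0 else 1e-6
        vmin, vmax = vmin - pad, vmax + pad
    return vmin, vmax

field = np.linspace(0.0, 100.0, 1001)
vmin, vmax = robust_limits(field)
print(vmin, vmax)  # 2.0 98.0
```

Clipping the color range this way keeps a handful of outliers from flattening the rest of the plot into a single color.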
SPECS.md ADDED
@@ -0,0 +1,181 @@
+ # SPECS.md β€” Panoply-Lite (Python/Gradio)
+
+ Goal: A browser-based viewer for **netCDF/HDF/GRIB/Zarr** datasets with an **xarray+Dask** backend, **Cartopy** maps, multi-dimensional slicing, animations, custom color tables (CPT/ACT/RGB), map projections, **OPeNDAP/THREDDS/S3/Zarr** support, and exports (PNG/SVG/PDF/MP4). Think β€œPanoply in the browser.”
+
+ ---
+
+ ## 0) TL;DR Build Contract
+
+ - **Deliverable**: `app.py` + `panlite/` package + tests + pinned `pyproject.toml` (or `requirements.txt`) + minimal sample assets.
+ - **Runtime**: Python 3.11; CPU-only acceptable; FFmpeg required for MP4.
+ - **UI**: Gradio Blocks (single page).
+ - **Perf**: Lazy I/O with Dask; responsive for slices ≀ ~5e6 elements.
+ - **Outcomes (Definition of Done)**: Load ERA5/CMIP-like datasets; produce a global map, a time–latitude section, a 1D line plot, an animation (MP4), and an A–B difference contour; export PNG/SVG/PDF; save/load view state JSON.
+
+ ---
+
+ ## 1) Scope
+
+ ### 1.1 MVP Features
+ - **Open sources**: local files; `https://` (incl. **OPeNDAP/THREDDS**); `s3://` (anonymous or creds); `zarr://` (local/remote).
+ - **Formats/engines**:
+   - netCDF/HDF5 β†’ `xarray.open_dataset(..., engine="h5netcdf" | "netcdf4")`
+   - GRIB β†’ `cfgrib` (requires ecCodes)
+   - Zarr β†’ `xarray.open_zarr(...)`
+ - **Discovery**: list variables, dims, shapes, dtypes, attrs; CF axis inference (lat, lon, time, level).
+ - **Slicing**: choose X/Y axes; set index/value selectors for remaining dims (nearest when numeric); time selectors.
+ - **Plots**:
+   - 1D: line (any single dim)
+   - 2D: image/contour (any two dims)
+   - Map: lon–lat georeferenced 2D with Cartopy (projections, coastlines, gridlines)
+   - Sections: time–lat, time–lev, lon–lev (pick dims)
+ - **Combine**: sum/avg/diff of two variables on a common grid (with reindex/broadcast; error out if CRS differs).
+ - **Colors**: matplotlib colormaps + load **CPT/ACT/RGB** tables.
+ - **Overlays**: coastlines, borders; optional land/ocean masks.
+ - **Export**: PNG/SVG/PDF; **MP4** (or frames) for animations.
+ - **State**: save/load a JSON β€œview state.”
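A rough sketch of how the `engine=None` default described later in this spec (`open_any`) could resolve against the format-to-engine table above. `guess_engine` and its extension map are illustrative, not part of the spec:

```python
from pathlib import Path

# Hypothetical helper: map a URI to an xarray engine, following the
# spec's table (h5netcdf/netcdf4 for netCDF/HDF5, cfgrib for GRIB, zarr for Zarr).
_ENGINES = {
    ".nc": "h5netcdf", ".nc4": "h5netcdf",
    ".h5": "h5netcdf", ".hdf": "netcdf4",
    ".grib": "cfgrib", ".grb": "cfgrib", ".grib2": "cfgrib",
}

def guess_engine(uri: str) -> str:
    """Guess an engine from the URI; fall back to netcdf4 (covers OPeNDAP)."""
    if uri.startswith("zarr://") or uri.rstrip("/").endswith(".zarr"):
        return "zarr"
    return _ENGINES.get(Path(uri).suffix.lower(), "netcdf4")

print(guess_engine("era5.grib"))   # cfgrib
print(guess_engine("s3://bucket/store.zarr/"))  # zarr
```

Sniffing magic bytes would be more robust than extensions, but an extension map is a reasonable first pass for local paths.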
+
+ ### 1.2 v1.1 Stretch (optional)
+ - **KMZ** export (ground overlay tiles).
+ - **GeoTIFF** export for georeferenced 2D arrays.
+ - **Trajectory** plots (CF trajectories).
+ - **Zonal mean** helper (lat/lon aggregation).
+ - **xESMF** regridding for combine.
+
+ ### 1.3 Non-Goals (v1)
+ 3D volume rendering; advanced reprojection pipelines; interactive WebGL.
+
+ ---
+
+ ## 2) Architecture & Layout
+
+ panoply-lite/
+   app.py
+   panlite/
+     __init__.py
+     io.py       # open/close; engine select; cache; variable listing
+     grid.py     # alignment, simple reindex/broadcast; combine ops; sections
+     plot.py     # 1D/2D/map plotting; exports
+     anim.py     # animate over a dim -> MP4/frames
+     colors.py   # CPT/ACT/RGB loaders -> matplotlib Colormap
+     state.py    # view-state (serialize/deserialize; schema validate)
+     utils.py    # CF axis guess; CRS helpers; small helpers
+   assets/
+     colormaps/  # sample .cpt/.act/.rgb
+   tests/
+     test_io.py
+     test_plot.py
+     test_anim.py
+     data/       # tiny sample datasets (≀5 MB)
+   pyproject.toml (or requirements.txt)
+   README.md
+   SPECS.md
+
+ ---
+
+ ## 3) Dependencies (pin reasonably)
+
+ - Core: xarray, dask[complete], numpy, pandas, fsspec, s3fs, zarr, h5netcdf, cfgrib, eccodes
+ - Geo: cartopy, pyproj, rioxarray
+ - Viz: matplotlib
+ - Web UI: gradio>=4
+ - Misc: pydantic, orjson, pytest, ruff
+ - System: ffmpeg in PATH for MP4
+
+ requirements.txt example:
+
+ xarray
+ dask[complete]
+ fsspec
+ s3fs
+ zarr
+ h5netcdf
+ cfgrib
+ eccodes
+ rioxarray
+ cartopy
+ pyproj
+ matplotlib
+ gradio>=4
+ numpy
+ pandas
+ pydantic
+ orjson
+ pytest
+
+ ---
+
+ ## 4) API Surface
+
+ ### io.py
+ - open_any(uri, engine=None, chunks="auto") -> DatasetHandle
+ - list_variables(handle) -> list[VariableSpec]
+ - get_dataarray(handle, var) -> xr.DataArray
+ - close(handle)
+
+ ### grid.py
+ - align_for_combine(a, b, method="reindex")
+ - combine(a, b, op="sum"|"avg"|"diff")
+ - section(da, along: str, fixed: dict)
+
+ ### plot.py
+ - plot_1d(da, **style) -> Figure
+ - plot_2d(da, kind="image"|"contour", **style) -> Figure
+ - plot_map(da, proj, **style) -> Figure
+ - export_fig(fig, fmt="png"|"svg"|"pdf", dpi=150, out_path=None)
+
+ ### anim.py
+ - animate_over_dim(da, dim: str, fps=10, out="anim.mp4") -> str
+
+ ### colors.py
+ - load_cpt(path) -> Colormap
+ - load_act(path) -> Colormap
+ - load_rgb(path) -> Colormap
+
+ ### state.py
+ - dump_state(dict) -> str (JSON)
+ - load_state(str) -> dict
+
+ ---
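A minimal sketch of the `state.py` contract listed above, assuming plain JSON with no schema validation (a real implementation would validate the dict first, e.g. with pydantic, as the layout section suggests):

```python
import json

def dump_state(state: dict) -> str:
    """Serialize a view state to a JSON string (keys sorted for stable diffs)."""
    return json.dumps(state, sort_keys=True)

def load_state(text: str) -> dict:
    """Parse a previously dumped view state."""
    return json.loads(text)

view = {"var": "t2m", "plot": "map", "proj": "Robinson", "cmap": "viridis"}
assert load_state(dump_state(view)) == view  # round trip is lossless
```

The round-trip property is what the acceptance criterion "save view state, reload reproduces identical figure" ultimately rests on.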
+
+ ## 5) UI Spec (Gradio Blocks)
+
+ Sidebar:
+ - File open (local, URL, S3)
+ - Dataset vars: pick A, optional B; operation (sum/avg/diff)
+ - Axes: X, Y; slice others (sliders, dropdowns)
+ - Plot type: 1D, 2D image, 2D contour, Map
+ - Projection: PlateCarree, Robinson, etc.
+ - Colormap: dropdown + upload CPT/ACT/RGB
+ - Colorbar options; vmin/vmax; levels
+ - Animation: dim=[time|lev], FPS, MP4 export
+ - Export: PNG/SVG/PDF
+ - Save/load view state JSON
+
+ Main Panel:
+ - Figure canvas (matplotlib)
+ - Metadata panel: units, attrs, dims
+
+ ---
+
+ ## 6) Acceptance Criteria
+
+ - Open ERA5 (Zarr/OPeNDAP) dataset.
+ - Plot: (a) global 2D map with Robinson projection; (b) time–lat section; (c) 1D time series.
+ - Perform A–B difference contour plot.
+ - Load custom CPT colormap and export as SVG/PDF.
+ - Animate over time dimension (24 frames), export MP4.
+ - Save view state; reloading it reproduces an identical figure.
+
+ ---
+
+ ## 7) Testing Checklist
+
+ - Local netCDF, remote OPeNDAP, public S3, Zarr open successfully.
+ - Variable discovery works; CF axis inference correct.
+ - 1D/2D/Map plotting functional.
+ - Combine AΒ±B correct (aligned grid).
+ - Custom CPT colormap applied correctly.
+ - Export PNG/SVG/PDF correct dimensions/DPI.
+ - Animation over time produces correct frame count, valid MP4.
+ - Large datasets remain responsive thanks to Dask's lazy evaluation.
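One way to check the "export PNG/SVG/PDF correct dimensions/DPI" item without opening a viewer is to read the pixel size straight out of the PNG header. A sketch using only matplotlib and the standard library (width/height live at byte offsets 16-24 of a PNG's IHDR chunk):

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless backend, suitable for CI
import matplotlib.pyplot as plt

# A 4x3 inch figure saved at dpi=100 should produce a 400x300 pixel PNG.
fig, ax = plt.subplots(figsize=(4, 3), dpi=100)
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100)
data = buf.getvalue()

assert data[:8] == b"\x89PNG\r\n\x1a\n"      # PNG signature
width = int.from_bytes(data[16:20], "big")   # IHDR width
height = int.from_bytes(data[20:24], "big")  # IHDR height
print(width, height)  # 400 300
```

Note that passing `bbox_inches="tight"` to `savefig` would change the output size, so the test must save with default bounds.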
app.py CHANGED
@@ -1,444 +1,334 @@
  import gradio as gr
- import xarray as xr
- import numpy as np
  import matplotlib.pyplot as plt
- import matplotlib.patches as patches
- import plotly.express as px
- import plotly.graph_objects as go
- from plotly.subplots import make_subplots
- import pandas as pd
- import tempfile
- import os
- from typing import Optional, Tuple, Dict, Any

- # Set matplotlib backend
- plt.switch_backend('Agg')

- def analyze_netcdf(file_path: str) -> Tuple[str, Dict[str, Any]]:
-     """Analyze NetCDF file and extract metadata."""
      try:
-         ds = xr.open_dataset(file_path)
-
-         # Basic info
-         info = {
-             'dimensions': dict(ds.dims),
-             'variables': list(ds.data_vars.keys()),
-             'coordinates': list(ds.coords.keys()),
-             'attrs': dict(ds.attrs),
-             'data_vars_info': {}
          }

-         # Detailed variable information
-         for var in ds.data_vars:
-             var_info = {
-                 'shape': ds[var].shape,
-                 'dtype': str(ds[var].dtype),
-                 'dims': ds[var].dims,
-                 'attrs': dict(ds[var].attrs),
-                 'min': float(ds[var].min().values) if ds[var].size > 0 else None,
-                 'max': float(ds[var].max().values) if ds[var].size > 0 else None,
-                 'mean': float(ds[var].mean().values) if ds[var].size > 0 else None
-             }
-             info['data_vars_info'][var] = var_info
-
-         # Generate summary text
-         summary = f"""
- ## Dataset Overview
- - **Dimensions**: {len(ds.dims)} ({', '.join([f"{k}: {v}" for k, v in ds.dims.items()])})
- - **Variables**: {len(ds.data_vars)} data variables, {len(ds.coords)} coordinates
- - **Global Attributes**: {len(ds.attrs)} attributes

- ### Variables:
  """
-         for var, var_info in info['data_vars_info'].items():
-             summary += f"- **{var}**: {var_info['shape']} ({var_info['dtype']})"
-             if var_info['min'] is not None:
-                 summary += f" [{var_info['min']:.2f} to {var_info['max']:.2f}]"
-             summary += "\n"

-         ds.close()
-         return summary, info

      except Exception as e:
-         return f"Error analyzing file: {str(e)}", {}

- def create_2d_plot(file_path: str, variable: str, time_idx: int = 0, level_idx: int = 0,
-                    colormap: str = "viridis") -> go.Figure:
-     """Create 2D visualization of NetCDF data."""
      try:
-         ds = xr.open_dataset(file_path)
-
-         if variable not in ds.data_vars:
-             raise ValueError(f"Variable '{variable}' not found in dataset")

-         data_var = ds[variable]

-         # Handle different dimensional data
-         if len(data_var.dims) >= 2:
-             # Find spatial dimensions (usually lat/lon or x/y)
-             spatial_dims = []
-             for dim in data_var.dims:
-                 if any(name in dim.lower() for name in ['lat', 'lon', 'x', 'y']):
-                     spatial_dims.append(dim)

-             if len(spatial_dims) >= 2:
-                 # Use the last two spatial dimensions
-                 dim1, dim2 = spatial_dims[-2:]
-
-                 # Select subset based on other dimensions
-                 data_subset = data_var
-                 for dim in data_var.dims:
-                     if dim not in [dim1, dim2]:
-                         if 'time' in dim.lower():
-                             data_subset = data_subset.isel({dim: min(time_idx, data_var.sizes[dim]-1)})
-                         elif any(name in dim.lower() for name in ['level', 'depth', 'height']):
-                             data_subset = data_subset.isel({dim: min(level_idx, data_var.sizes[dim]-1)})
-                         else:
-                             data_subset = data_subset.isel({dim: 0})
              else:
-                 # Use first two dimensions
-                 dims = list(data_var.dims)
-                 if len(dims) >= 2:
-                     data_subset = data_var.isel({dim: 0 for dim in dims[2:]})
-                 else:
-                     data_subset = data_var
-                 dim1, dim2 = dims[:2]
-         else:
-             raise ValueError("Data must have at least 2 dimensions for 2D plotting")
-
-         # Create the plot
-         fig = go.Figure(data=go.Heatmap(
-             z=data_subset.values,
-             x=data_subset.coords[dim2].values if dim2 in data_subset.coords else None,
-             y=data_subset.coords[dim1].values if dim1 in data_subset.coords else None,
-             colorscale=colormap,
-             colorbar=dict(title=data_var.attrs.get('units', 'Value'))
-         ))
-
-         fig.update_layout(
-             title=f"{variable} - {data_var.attrs.get('long_name', variable)}",
-             xaxis_title=dim2,
-             yaxis_title=dim1,
-             height=600,
-             width=800
-         )
-
-         ds.close()
-         return fig
-
-     except Exception as e:
-         # Return empty figure with error message
-         fig = go.Figure()
-         fig.add_annotation(
-             text=f"Error creating plot: {str(e)}",
-             x=0.5, y=0.5,
-             xref="paper", yref="paper",
-             showarrow=False,
-             font=dict(size=16, color="red")
-         )
-         return fig
-
- def create_time_series(file_path: str, variable: str, method: str = "mean") -> go.Figure:
-     """Create time series plot by aggregating spatial dimensions."""
-     try:
-         ds = xr.open_dataset(file_path)
-
-         if variable not in ds.data_vars:
-             raise ValueError(f"Variable '{variable}' not found in dataset")
-
-         data_var = ds[variable]
-
-         # Find time dimension
-         time_dim = None
-         for dim in data_var.dims:
-             if 'time' in dim.lower():
-                 time_dim = dim
-                 break
-
-         if time_dim is None:
-             raise ValueError("No time dimension found in the data")
-
-         # Aggregate spatial dimensions
-         spatial_dims = [dim for dim in data_var.dims if dim != time_dim]

-         if method == "mean":
-             time_series = data_var.mean(dim=spatial_dims)
-         elif method == "max":
-             time_series = data_var.max(dim=spatial_dims)
-         elif method == "min":
-             time_series = data_var.min(dim=spatial_dims)
          else:
-             time_series = data_var.mean(dim=spatial_dims)

-         fig = go.Figure(data=go.Scatter(
-             x=time_series.coords[time_dim].values,
-             y=time_series.values,
-             mode='lines+markers',
-             name=f"{method.title()} {variable}"
-         ))

-         fig.update_layout(
-             title=f"Time Series: {method.title()} {variable}",
-             xaxis_title="Time",
-             yaxis_title=f"{variable} ({data_var.attrs.get('units', 'Value')})",
-             height=400
-         )
-
-         ds.close()
-         return fig
-
-     except Exception as e:
-         fig = go.Figure()
-         fig.add_annotation(
-             text=f"Error creating time series: {str(e)}",
-             x=0.5, y=0.5,
-             xref="paper", yref="paper",
-             showarrow=False,
-             font=dict(size=16, color="red")
-         )
-         return fig
-
- def create_vertical_profile(file_path: str, variable: str, time_idx: int = 0) -> go.Figure:
-     """Create vertical profile plot."""
-     try:
-         ds = xr.open_dataset(file_path)
-
-         if variable not in ds.data_vars:
-             raise ValueError(f"Variable '{variable}' not found in dataset")
-
-         data_var = ds[variable]
-
-         # Find vertical dimension
-         vertical_dim = None
-         for dim in data_var.dims:
-             if any(name in dim.lower() for name in ['level', 'depth', 'height', 'pressure']):
-                 vertical_dim = dim
-                 break
-
-         if vertical_dim is None:
-             raise ValueError("No vertical dimension found in the data")
-
-         # Average over horizontal dimensions, select time
-         dims_to_avg = []
-         for dim in data_var.dims:
-             if dim != vertical_dim:
-                 if 'time' in dim.lower():
-                     data_var = data_var.isel({dim: min(time_idx, data_var.sizes[dim]-1)})
-                 else:
-                     dims_to_avg.append(dim)
-
-         if dims_to_avg:
-             profile = data_var.mean(dim=dims_to_avg)
          else:
-             profile = data_var
-
-         fig = go.Figure(data=go.Scatter(
-             x=profile.values,
-             y=profile.coords[vertical_dim].values,
-             mode='lines+markers',
-             name=variable
-         ))

-         fig.update_layout(
-             title=f"Vertical Profile: {variable}",
-             xaxis_title=f"{variable} ({data_var.attrs.get('units', 'Value')})",
-             yaxis_title=vertical_dim,
-             height=500
-         )

-         ds.close()
-         return fig

      except Exception as e:
-         fig = go.Figure()
-         fig.add_annotation(
-             text=f"Error creating profile: {str(e)}",
-             x=0.5, y=0.5,
-             xref="paper", yref="paper",
-             showarrow=False,
-             font=dict(size=16, color="red")
-         )
-         return fig

- def process_netcdf_file(file):
-     """Process uploaded NetCDF file and return analysis."""
-     if file is None:
-         return "Please upload a NetCDF file.", {}, [], []

-     try:
-         # Save uploaded file temporarily
-         with tempfile.NamedTemporaryFile(delete=False, suffix='.nc') as tmp_file:
-             tmp_file.write(file.read())
-             tmp_path = tmp_file.name
-
-         # Analyze the file
-         summary, info = analyze_netcdf(tmp_path)
-
-         # Get variable options
-         variable_options = list(info.get('data_vars_info', {}).keys())
-
-         # Get dimension options for slicing
-         dimensions = info.get('dimensions', {})
-
-         return summary, tmp_path, variable_options, list(dimensions.keys())
-
-     except Exception as e:
-         return f"Error processing file: {str(e)}", "", [], []
-
- def update_plot(file_path: str, variable: str, plot_type: str, time_idx: int,
-                 level_idx: int, colormap: str, aggregation_method: str):
-     """Update plot based on user selections."""
-     if not file_path or not variable:
-         return go.Figure()

-     try:
-         if plot_type == "2D Heatmap":
-             return create_2d_plot(file_path, variable, time_idx, level_idx, colormap)
-         elif plot_type == "Time Series":
-             return create_time_series(file_path, variable, aggregation_method)
-         elif plot_type == "Vertical Profile":
-             return create_vertical_profile(file_path, variable, time_idx)
-         else:
-             return go.Figure()
-     except Exception as e:
-         fig = go.Figure()
-         fig.add_annotation(
-             text=f"Error: {str(e)}",
-             x=0.5, y=0.5,
-             xref="paper", yref="paper",
-             showarrow=False
-         )
-         return fig
-
- # Create Gradio interface
- with gr.Blocks(title="NetCDF Explorer 🌍", theme=gr.themes.Soft()) as app:
-     gr.Markdown("""
-     # 🌍 NetCDF Explorer

-     Upload and explore NetCDF (.nc) files with interactive visualizations!

-     **Features:**
-     - πŸ“Š Interactive 2D heatmaps
-     - πŸ“ˆ Time series analysis
-     - πŸ“‰ Vertical profiles
-     - 🎨 Customizable colormaps
-     - πŸ“‹ Comprehensive metadata analysis
-     """)

-     # File upload section
-     with gr.Row():
-         file_upload = gr.File(
-             label="Upload NetCDF File (.nc)",
-             file_types=[".nc", ".netcdf"],
-             type="binary"
-         )

-     # File analysis section
-     with gr.Row():
-         file_info = gr.Markdown("Upload a file to see its structure and metadata.")

-     # Control panel
      with gr.Row():
          with gr.Column(scale=1):
-             variable_dropdown = gr.Dropdown(
-                 label="Select Variable",
                  choices=[],
                  interactive=True
              )

              plot_type = gr.Radio(
                  label="Plot Type",
-                 choices=["2D Heatmap", "Time Series", "Vertical Profile"],
-                 value="2D Heatmap"
              )

-             colormap_dropdown = gr.Dropdown(
                  label="Colormap",
-                 choices=["viridis", "plasma", "inferno", "magma", "cividis",
-                          "Blues", "Reds", "RdYlBu", "RdBu", "coolwarm"],
                  value="viridis"
              )

-             aggregation_method = gr.Radio(
-                 label="Time Series Aggregation",
-                 choices=["mean", "max", "min"],
-                 value="mean",
-                 visible=False
-             )

-         with gr.Column(scale=1):
-             time_slider = gr.Slider(
-                 label="Time Index",
-                 minimum=0,
-                 maximum=100,
-                 value=0,
-                 step=1
-             )

-             level_slider = gr.Slider(
-                 label="Level Index",
-                 minimum=0,
-                 maximum=100,
-                 value=0,
-                 step=1
-             )
-
-     update_btn = gr.Button("Update Plot", variant="primary")
-
-     # Plot display
-     plot_display = gr.Plot(label="Visualization")

-     # Hidden state to store file path
      file_path_state = gr.State("")

      # Event handlers
-     def on_file_upload(file):
-         summary, tmp_path, variables, dimensions = process_netcdf_file(file)
-
-         # Update UI components
-         updates = [
-             gr.update(value=summary),  # file_info
-             gr.update(choices=variables, value=variables[0] if variables else None),  # variable_dropdown
-             gr.update(value=tmp_path),  # file_path_state
-         ]
-
-         return updates

-     def on_plot_type_change(plot_type_val):
-         if plot_type_val == "Time Series":
-             return gr.update(visible=True)
-         else:
-             return gr.update(visible=False)

-     def on_update_plot(file_path, variable, plot_type_val, time_idx, level_idx, colormap, agg_method):
-         return update_plot(file_path, variable, plot_type_val, int(time_idx), int(level_idx), colormap, agg_method)

-     # Connect event handlers
      file_upload.upload(
-         fn=on_file_upload,
          inputs=[file_upload],
-         outputs=[file_info, variable_dropdown, file_path_state]
      )

-     plot_type.change(
-         fn=on_plot_type_change,
-         inputs=[plot_type],
-         outputs=[aggregation_method]
      )

-     update_btn.click(
-         fn=on_update_plot,
-         inputs=[file_path_state, variable_dropdown, plot_type, time_slider,
-                 level_slider, colormap_dropdown, aggregation_method],
-         outputs=[plot_display]
      )

-     # Auto-update on variable change
-     variable_dropdown.change(
-         fn=on_update_plot,
-         inputs=[file_path_state, variable_dropdown, plot_type, time_slider,
-                 level_slider, colormap_dropdown, aggregation_method],
-         outputs=[plot_display]
      )

  if __name__ == "__main__":
 
1
+ """TensorView - Final version for Hugging Face deployment."""
2
+
3
+ import os
4
  import gradio as gr
 
 
5
  import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from typing import Optional, Dict, Any, List, Tuple
8
+
9
+ # TensorView imports
10
+ from tensorview.io import open_any, list_variables, get_dataarray, close, DatasetHandle
11
+ from tensorview.plot import plot_1d, plot_2d, plot_map, export_fig
12
+ from tensorview.colors import get_matplotlib_colormaps
13
+ from tensorview.utils import identify_coordinates
14
 
15
+ # Global state
16
+ current_handle: Optional[DatasetHandle] = None
17
+ current_data: Optional[Dict[str, Any]] = None
18
 
19
+ def process_file_or_path(file_or_path) -> Tuple[str, List[str], str]:
20
+ """Process uploaded file or file path."""
21
+ global current_handle, current_data
22
+
23
+ # Determine if it's a file object or path string
24
+ if hasattr(file_or_path, 'name'):
25
+ file_path = file_or_path.name
26
+ source = f"Uploaded: {os.path.basename(file_path)}"
27
+ else:
28
+ file_path = str(file_or_path)
29
+ source = f"Path: {os.path.basename(file_path)}"
30
+
31
  try:
32
+ # Close previous handle
33
+ if current_handle:
34
+ close(current_handle)
35
+
36
+ # Open dataset
37
+ current_handle = open_any(file_path)
38
+ variables = list_variables(current_handle)
39
+
40
+ # Get variable info and dataset info
41
+ var_choices = [v.name for v in variables]
42
+ ds = current_handle.dataset
43
+
44
+ # Store dimension information
45
+ dims_info = {}
46
+ for dim_name, dim_size in ds.dims.items():
47
+ if dim_name in ds.coords:
48
+ coord = ds.coords[dim_name]
49
+ dims_info[dim_name] = {
50
+ 'size': dim_size,
51
+ 'values': coord.values,
52
+ 'min': float(coord.min().values),
53
+ 'max': float(coord.max().values)
54
+ }
55
+
56
+ current_data = {
57
+ 'variables': {v.name: {'dims': v.dims, 'shape': v.shape, 'long_name': v.long_name, 'units': v.units}
58
+ for v in variables},
59
+ 'dimensions': dims_info
60
  }
61
 
62
+ # Create summary
63
+ summary = f"""βœ… **Dataset Loaded Successfully!**
64
+
65
+ **Source:** {source}
66
+ **Engine:** {current_handle.engine}
67
+ **Dimensions:** {len(ds.dims)} ({', '.join([f"{k}: {v}" for k, v in ds.dims.items()])})
68
+ **Variables:** {len(var_choices)}
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ ### Available Variables:
71
  """
72
+ for v in variables:
73
+ summary += f"- **{v.name}**: {v.shape} [{v.units}] - {v.long_name}\n"
 
 
 
74
 
75
+ return summary, var_choices, file_path
 
76
 
77
  except Exception as e:
78
+ return f"❌ **Error loading data:** {str(e)}", [], ""
79
 
80
+ def create_plot_with_auto_slicing(file_path: str, variable: str, plot_type: str,
81
+ colormap: str, *slider_values) -> Tuple[plt.Figure, str]:
82
+ """Create plot with automatic dimension reduction."""
83
+ global current_handle, current_data
84
+
85
+ if not current_handle or not variable or not current_data:
86
+ fig, ax = plt.subplots(figsize=(10, 6))
87
+ ax.text(0.5, 0.5, '🚫 No data loaded', ha='center', va='center',
88
+ transform=ax.transAxes, fontsize=16)
89
+ return fig, ""
90
+
91
  try:
92
+ # Get data
93
+ da = get_dataarray(current_handle, variable)
94
+ var_info = current_data['variables'][variable]
 
95
 
96
+ # Apply dimension slicing
97
+ selection_dict = {}
98
+ slider_idx = 0
99
 
100
+ for dim in da.dims:
101
+ if any(spatial in dim.lower() for spatial in ['lat', 'lon', 'x', 'y']):
102
+                continue  # Keep spatial dimensions
+
+            # Use slider value or default to 0
+            if slider_idx < len(slider_values) and slider_values[slider_idx] is not None:
+                selection_dict[dim] = int(slider_values[slider_idx])
             else:
+                selection_dict[dim] = 0
+            slider_idx += 1
+
+        # Apply slicing
+        if selection_dict:
+            da_plot = da.isel(selection_dict)
         else:
+            da_plot = da
+
+        # Calculate smart color limits using percentiles
+        data_values = da_plot.values
+        finite_data = data_values[np.isfinite(data_values)]
+
+        if len(finite_data) > 0:
+            data_min = float(np.percentile(finite_data, 2))
+            data_max = float(np.percentile(finite_data, 98))
+
+            # Ensure min != max
+            if abs(data_max - data_min) < 1e-10:
+                data_range = abs(data_min) * 0.1 if data_min != 0 else 1e-6
+                data_min -= data_range
+                data_max += data_range
         else:
+            data_min, data_max = 0, 1
+
+        # Style parameters
+        style = {
+            'cmap': colormap,
+            'colorbar': True,
+            'vmin': data_min,
+            'vmax': data_max
+        }
+
+        # Create appropriate plot
+        n_dims = len(da_plot.dims)
+
+        if n_dims == 1:
+            fig = plot_1d(da_plot, **style)
+        elif n_dims == 2:
+            coords = identify_coordinates(da_plot)
+            if 'X' in coords and 'Y' in coords and plot_type == "Map":
+                style.update({
+                    'proj': 'PlateCarree',
+                    'coastlines': True,
+                    'gridlines': True
+                })
+                fig = plot_map(da_plot, **style)
+            else:
+                fig = plot_2d(da_plot, kind="image", **style)
+        else:
+            fig, ax = plt.subplots(figsize=(10, 6))
+            ax.text(0.5, 0.5, f'🔧 Cannot plot {n_dims}D data.\nAdjust sliders to reduce dimensions.',
+                    ha='center', va='center', transform=ax.transAxes, fontsize=12)
+            return fig, f"❌ Too many dimensions: {n_dims}D"
+
+        # Add selection info to title
+        if selection_dict:
+            selection_str = ', '.join([f"{k}={v}" for k, v in selection_dict.items()])
+            fig.suptitle(f"{var_info['long_name']} ({selection_str})", fontsize=14)
+
+        # Create info string
+        data_info = f"""📊 **Plot Info:**
+- **Shape:** {da_plot.shape}
+- **Range:** {data_min:.4g} to {data_max:.4g} {var_info['units']}
+- **Plot:** {plot_type} | **Colormap:** {colormap}
+- **Valid data:** {len(finite_data):,} points
+"""
+
+        return fig, data_info
 
     except Exception as e:
+        fig, ax = plt.subplots(figsize=(10, 6))
+        ax.text(0.5, 0.5, f'❌ Error:\n{str(e)}', ha='center', va='center',
+                transform=ax.transAxes, fontsize=12, color='red')
+        return fig, f"❌ Error: {str(e)}"
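The 2–98 percentile clipping above can be sanity-checked in isolation. A minimal standalone sketch (the `robust_limits` name and the sample array are illustrative, not part of app.py):

```python
import numpy as np

def robust_limits(values, lo=2, hi=98):
    """Percentile-based color limits that ignore NaN/Inf cells."""
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0.0, 1.0  # no valid data: fall back to a unit range
    vmin = float(np.percentile(finite, lo))
    vmax = float(np.percentile(finite, hi))
    if abs(vmax - vmin) < 1e-10:  # constant field: widen artificially
        pad = abs(vmin) * 0.1 if vmin != 0 else 1e-6
        vmin, vmax = vmin - pad, vmax + pad
    return vmin, vmax

data = np.array([0.0, 1.0, 2.0, np.nan, 100.0])  # one outlier, one NaN
vmin, vmax = robust_limits(data)  # outlier and NaN no longer dominate the scale
```

The point of the percentile (rather than min/max) limits is that a single bad cell or fill value no longer washes out the colorbar.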
 
 
 
 
 
+def update_sliders_for_variable(variable: str):
+    """Update sliders based on selected variable."""
+    global current_data
+
+    if not variable or not current_data:
+        return [gr.update(visible=False) for _ in range(8)]
+
+    var_info = current_data['variables'][variable]
+    dims = var_info['dims']
+    dims_info = current_data['dimensions']
+
+    updates = []
+    for dim in dims:
+        if any(spatial in dim.lower() for spatial in ['lat', 'lon', 'x', 'y']):
+            continue  # Skip spatial dimensions
+
+        if len(updates) < 8 and dim in dims_info:
+            dim_data = dims_info[dim]
+            updates.append(gr.update(
+                visible=True,
+                label=f"{dim} (0-{dim_data['size']-1})",
+                minimum=0,
+                maximum=dim_data['size'] - 1,
+                value=0,
+                step=1
+            ))
+
+    # Hide remaining sliders
+    while len(updates) < 8:
+        updates.append(gr.update(visible=False))
+
+    return updates
+
+# Create Gradio Interface
+with gr.Blocks(title="🌍 TensorView - NetCDF/HDF/GRIB Viewer", theme=gr.themes.Ocean()) as app:
+
+    gr.HTML("""
+    <div style="text-align: center; padding: 20px;">
+        <h1>🌍 TensorView</h1>
+        <p><strong>Interactive viewer for NetCDF, HDF, GRIB, and Zarr datasets</strong></p>
+        <p>Upload your data files or provide a file path, then explore with maps and plots!</p>
+    </div>
+    """)
+
     with gr.Row():
         with gr.Column(scale=1):
+            # Data Input Section
+            gr.Markdown("### 📁 Data Input")
+            with gr.Tabs():
+                with gr.Tab("📤 Upload File"):
+                    file_upload = gr.File(
+                        label="Select data file",
+                        file_types=[".nc", ".netcdf", ".hdf", ".h5", ".grib", ".grb"]
+                    )
+                with gr.Tab("📂 File Path"):
+                    file_path_input = gr.Textbox(
+                        label="Enter file path",
+                        placeholder="/path/to/your/data.nc",
+                        lines=2
+                    )
+                    load_path_btn = gr.Button("Load File", variant="primary")
+
+            # Plot Configuration
+            gr.Markdown("### 🎨 Plot Configuration")
+            variable_select = gr.Dropdown(
+                label="Variable",
                 choices=[],
                 interactive=True
             )
 
             plot_type = gr.Radio(
                 label="Plot Type",
+                choices=["2D Image", "Map"],
+                value="2D Image"
             )
 
+            colormap_select = gr.Dropdown(
                 label="Colormap",
+                choices=["viridis", "plasma", "coolwarm", "RdBu_r", "Blues", "Reds", "turbo"],
                 value="viridis"
             )
 
+            # Dimension Sliders
+            gr.Markdown("### 🎛️ Dimension Controls")
+            gr.Markdown("*Sliders appear automatically based on selected variable*")
+
+            slider_components = [
+                gr.Slider(visible=False, interactive=True) for _ in range(8)
+            ]
+
+            plot_btn = gr.Button("🎨 Create Plot", variant="primary", size="lg")
+
+        with gr.Column(scale=2):
+            # Results Section
+            data_info = gr.Markdown("*Load a dataset to begin*")
+            plot_display = gr.Plot(label="Visualization")
+            plot_info = gr.Markdown("")
 
+    # Hidden state
     file_path_state = gr.State("")
 
     # Event handlers
+    def handle_file_upload(file):
+        summary, variables, path = process_file_or_path(file)
+        return (
+            summary,
+            gr.update(choices=variables, value=variables[0] if variables else None),
+            path
+        )
 
+    def handle_path_load(file_path):
+        summary, variables, path = process_file_or_path(file_path)
+        return (
+            summary,
+            gr.update(choices=variables, value=variables[0] if variables else None),
+            path
+        )
+
+    def handle_variable_change(variable):
+        return update_sliders_for_variable(variable)
 
+    def handle_plot_creation(file_path, variable, plot_type, colormap, *slider_vals):
+        fig, info = create_plot_with_auto_slicing(file_path, variable, plot_type, colormap, *slider_vals)
+        return fig, info
 
+    # Connect events
     file_upload.upload(
+        fn=handle_file_upload,
         inputs=[file_upload],
+        outputs=[data_info, variable_select, file_path_state]
     )
 
+    load_path_btn.click(
+        fn=handle_path_load,
+        inputs=[file_path_input],
+        outputs=[data_info, variable_select, file_path_state]
     )
 
+    variable_select.change(
+        fn=handle_variable_change,
+        inputs=[variable_select],
+        outputs=slider_components
     )
 
+    plot_btn.click(
+        fn=handle_plot_creation,
+        inputs=[file_path_state, variable_select, plot_type, colormap_select] + slider_components,
+        outputs=[plot_display, plot_info]
     )
 
 if __name__ == "__main__":
assets/colormaps/sample.cpt ADDED
@@ -0,0 +1,12 @@
+# Sample GMT-style colormap
+# Blue to Red temperature colormap
+-10.0 0 0 255 -5.0 50 50 255
+-5.0 50 50 255 0.0 100 100 255
+0.0 100 100 255 5.0 150 200 255
+5.0 150 200 255 10.0 200 255 200
+10.0 200 255 200 15.0 255 255 100
+15.0 255 255 100 20.0 255 200 50
+20.0 255 200 50 25.0 255 150 0
+25.0 255 150 0 30.0 255 100 0
+30.0 255 100 0 35.0 255 50 0
+35.0 255 50 0 40.0 255 0 0
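Each non-comment row in this sample encodes one color segment as `z0 r0 g0 b0 z1 r1 g1 b1`. A minimal row-parser sketch, assuming 0–255 RGB components as in this sample (a standalone illustration, not the loader shipped in `tensorview/colors.py`):

```python
def parse_cpt_row(line):
    """Split one CPT segment row into ((z0, rgb0), (z1, rgb1)) with RGB scaled to 0-1."""
    z0, r0, g0, b0, z1, r1, g1, b1 = map(float, line.split()[:8])
    norm = lambda r, g, b: (r / 255, g / 255, b / 255)  # assumes 0-255 inputs
    return (z0, norm(r0, g0, b0)), (z1, norm(r1, g1, b1))

# First row of the sample palette: pure blue fading toward lighter blue
seg = parse_cpt_row("-10.0 0 0 255 -5.0 50 50 255")
```

A full loader (like the one in `tensorview/colors.py`) also has to skip comments and the `B`/`F`/`N` back/fore/NaN rows, then normalize the z values to the 0–1 range matplotlib expects.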
requirements.txt CHANGED
@@ -1,10 +1,20 @@
-gradio==5.42.0
-xarray>=2023.1.0
-netcdf4>=1.6.0
-numpy>=1.21.0
-matplotlib>=3.5.0
-plotly>=5.10.0
-pandas>=1.5.0
-scipy>=1.9.0
-h5netcdf>=1.0.0
-dask>=2022.1.0
+xarray
+dask[complete]
+fsspec
+s3fs
+zarr
+h5netcdf
+netcdf4
+cfgrib
+eccodes
+rioxarray
+cartopy
+pyproj
+matplotlib
+gradio>=4
+numpy
+pandas
+pydantic
+orjson
+pytest
+ruff
tensorview/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""TensorView: A browser-based netCDF/HDF/GRIB/Zarr dataset viewer."""
+
+__version__ = "1.0.0"
tensorview/anim.py ADDED
@@ -0,0 +1,351 @@
+"""Animation functionality for creating MP4 videos from multi-dimensional data."""
+
+import os
+import tempfile
+import subprocess
+from typing import Optional, Callable, List
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.animation import FuncAnimation, FFMpegWriter
+import xarray as xr
+
+from .plot import plot_1d, plot_2d, plot_map, setup_matplotlib
+from .utils import identify_coordinates, format_value
+
+
+def check_ffmpeg():
+    """Check if FFmpeg is available."""
+    try:
+        subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
+        return True
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        return False
+
+
+def animate_over_dim(da: xr.DataArray, dim: str, plot_func: Callable = None,
+                     fps: int = 10, out: str = "animation.mp4",
+                     figsize: tuple = (10, 8), **plot_kwargs) -> str:
+    """
+    Create an animation over a specified dimension.
+
+    Args:
+        da: Input DataArray
+        dim: Dimension to animate over
+        plot_func: Plotting function to use (auto-detected if None)
+        fps: Frames per second
+        out: Output file path
+        figsize: Figure size
+        **plot_kwargs: Additional plotting parameters
+
+    Returns:
+        Path to the created animation file
+    """
+    if not check_ffmpeg():
+        raise RuntimeError("FFmpeg is required for creating MP4 animations")
+
+    if dim not in da.dims:
+        raise ValueError(f"Dimension '{dim}' not found in DataArray")
+
+    setup_matplotlib()
+
+    # Get coordinate values for the animation dimension
+    coord_vals = da.coords[dim].values
+    n_frames = len(coord_vals)
+
+    if n_frames < 2:
+        raise ValueError(f"Need at least 2 frames for animation, got {n_frames}")
+
+    # Auto-detect plot function if not provided
+    if plot_func is None:
+        remaining_dims = [d for d in da.dims if d != dim]
+        n_remaining = len(remaining_dims)
+
+        # Check if we have geographic coordinates
+        coords = identify_coordinates(da)
+        has_geo = 'X' in coords and 'Y' in coords
+
+        if n_remaining == 1:
+            plot_func = plot_1d
+        elif n_remaining == 2 and has_geo:
+            plot_func = plot_map
+        elif n_remaining == 2:
+            plot_func = plot_2d
+        else:
+            raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")
+
+    # Create figure and initial axes. Map animations need a cartopy GeoAxes up
+    # front, otherwise ax.coastlines()/gridlines() would fail on a plain Axes.
+    if plot_func == plot_map:
+        import cartopy.crs as ccrs
+
+        proj = plot_kwargs.get('proj', 'PlateCarree')
+        proj_map = {
+            'PlateCarree': ccrs.PlateCarree(),
+            'Robinson': ccrs.Robinson(),
+            'Mollweide': ccrs.Mollweide()
+        }
+        projection = proj_map.get(proj, ccrs.PlateCarree())
+        fig, ax = plt.subplots(figsize=figsize,
+                               subplot_kw={'projection': projection})
+    else:
+        fig, ax = plt.subplots(figsize=figsize)
+
+    # Get initial frame
+    initial_frame = da.isel({dim: 0})
+
+    # Set up consistent color limits across all frames
+    if 'vmin' not in plot_kwargs:
+        plot_kwargs['vmin'] = float(da.min().values)
+    if 'vmax' not in plot_kwargs:
+        plot_kwargs['vmax'] = float(da.max().values)
+
+    # Create initial plot to get the structure
+    if plot_func == plot_1d:
+        line, = ax.plot([], [])
+        ax.set_xlim(float(initial_frame.coords[initial_frame.dims[0]].min()),
+                    float(initial_frame.coords[initial_frame.dims[0]].max()))
+        ax.set_ylim(plot_kwargs['vmin'], plot_kwargs['vmax'])
+
+        # Set labels
+        x_dim = initial_frame.dims[0]
+        ax.set_xlabel(f"{x_dim} ({initial_frame.coords[x_dim].attrs.get('units', '')})")
+        ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
+
+        def animate(frame_idx):
+            frame_data = da.isel({dim: frame_idx})
+            x_data = frame_data.coords[x_dim]
+            line.set_data(x_data, frame_data)
+
+            # Update title with current time/coordinate value
+            coord_val = coord_vals[frame_idx]
+            coord_str = format_value(coord_val, dim)
+            title = f"{da.attrs.get('long_name', da.name or 'Data')} - {dim}={coord_str}"
+            ax.set_title(title)
+
+            return line,
+
+    elif plot_func in [plot_2d, plot_map]:
+        # For 2D plots, we need to recreate the plot each frame
+        def animate(frame_idx):
+            ax.clear()
+            frame_data = da.isel({dim: frame_idx})
+
+            # Create the plot
+            if plot_func == plot_map:
+                # ax is a GeoAxes (created above), so map methods are available
+                coords = identify_coordinates(frame_data)
+                lon_dim = coords['X']
+                lat_dim = coords['Y']
+
+                lons = frame_data.coords[lon_dim].values
+                lats = frame_data.coords[lat_dim].values
+
+                # Create pcolormesh plot
+                cmap = plot_kwargs.get('cmap', 'viridis')
+                im = ax.pcolormesh(lons, lats, frame_data.transpose(lat_dim, lon_dim).values,
+                                   cmap=cmap, vmin=plot_kwargs['vmin'], vmax=plot_kwargs['vmax'],
+                                   transform=ccrs.PlateCarree(), shading='auto')
+
+                # Add map features
+                if plot_kwargs.get('coastlines', True):
+                    ax.coastlines(resolution='50m', color='black', linewidth=0.5)
+                if plot_kwargs.get('gridlines', True):
+                    ax.gridlines(alpha=0.5)
+
+                ax.set_global()
+
+            else:
+                # Regular 2D plot
+                coords = identify_coordinates(frame_data)
+                x_dim = coords.get('X', frame_data.dims[-1])
+                y_dim = coords.get('Y', frame_data.dims[-2])
+
+                frame_plot = frame_data.transpose(y_dim, x_dim)
+                x_coord = frame_data.coords[x_dim]
+                y_coord = frame_data.coords[y_dim]
+
+                im = ax.imshow(frame_plot.values,
+                               extent=[float(x_coord.min()), float(x_coord.max()),
+                                       float(y_coord.min()), float(y_coord.max())],
+                               aspect='auto', origin='lower',
+                               cmap=plot_kwargs.get('cmap', 'viridis'),
+                               vmin=plot_kwargs['vmin'], vmax=plot_kwargs['vmax'])
+
+                ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
+                ax.set_ylabel(f"{y_dim} ({y_coord.attrs.get('units', '')})")
+
+            # Update title
+            coord_val = coord_vals[frame_idx]
+            coord_str = format_value(coord_val, dim)
+            title = f"{da.attrs.get('long_name', da.name or 'Data')} - {dim}={coord_str}"
+            ax.set_title(title)
+
+            return [im]
+
+    # Create animation
+    anim = FuncAnimation(fig, animate, frames=n_frames, interval=1000 // fps, blit=False)
+
+    # Save animation
+    try:
+        # Use FFmpeg writer
+        writer = FFMpegWriter(fps=fps, metadata=dict(artist='TensorView'), bitrate=1800)
+        anim.save(out, writer=writer)
+
+        plt.close(fig)
+        return out
+
+    except Exception as e:
+        plt.close(fig)
+        raise RuntimeError(f"Failed to create animation: {str(e)}")
+
+
+def create_frame_sequence(da: xr.DataArray, dim: str, plot_func: Callable = None,
+                          output_dir: str = "frames", **plot_kwargs) -> List[str]:
+    """
+    Create a sequence of individual frame images.
+
+    Args:
+        da: Input DataArray
+        dim: Dimension to animate over
+        plot_func: Plotting function to use
+        output_dir: Directory to save frames
+        **plot_kwargs: Additional plotting parameters
+
+    Returns:
+        List of frame file paths
+    """
+    if dim not in da.dims:
+        raise ValueError(f"Dimension '{dim}' not found in DataArray")
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    coord_vals = da.coords[dim].values
+    frame_paths = []
+
+    # Auto-detect plot function if not provided
+    if plot_func is None:
+        remaining_dims = [d for d in da.dims if d != dim]
+        n_remaining = len(remaining_dims)
+
+        coords = identify_coordinates(da)
+        has_geo = 'X' in coords and 'Y' in coords
+
+        if n_remaining == 1:
+            plot_func = plot_1d
+        elif n_remaining == 2 and has_geo:
+            plot_func = plot_map
+        elif n_remaining == 2:
+            plot_func = plot_2d
+        else:
+            raise ValueError(f"Cannot auto-detect plot type for {n_remaining}D data")
+
+    # Set consistent color limits
+    if 'vmin' not in plot_kwargs:
+        plot_kwargs['vmin'] = float(da.min().values)
+    if 'vmax' not in plot_kwargs:
+        plot_kwargs['vmax'] = float(da.max().values)
+
+    # Create frames
+    for i, coord_val in enumerate(coord_vals):
+        frame_data = da.isel({dim: i})
+
+        # Create plot
+        fig = plot_func(frame_data, **plot_kwargs)
+
+        # Update title with coordinate value
+        coord_str = format_value(coord_val, dim)
+        fig.suptitle(f"{da.attrs.get('long_name', da.name or 'Data')} - {dim}={coord_str}")
+
+        # Save frame
+        frame_path = os.path.join(output_dir, f"frame_{i:04d}.png")
+        fig.savefig(frame_path, dpi=150, bbox_inches='tight')
+        frame_paths.append(frame_path)
+
+        plt.close(fig)
+
+    return frame_paths
+
+
+def frames_to_mp4(frame_dir: str, output_path: str, fps: int = 10, cleanup: bool = True) -> str:
+    """
+    Convert a directory of frame images to MP4 video.
+
+    Args:
+        frame_dir: Directory containing frame images
+        output_path: Output MP4 file path
+        fps: Frames per second
+        cleanup: Whether to delete frame files after conversion
+
+    Returns:
+        Path to created MP4 file
+    """
+    if not check_ffmpeg():
+        raise RuntimeError("FFmpeg is required for MP4 conversion")
+
+    # Build FFmpeg command
+    cmd = [
+        'ffmpeg', '-y',  # Overwrite output
+        '-framerate', str(fps),
+        '-pattern_type', 'glob',
+        '-i', os.path.join(frame_dir, 'frame_*.png'),
+        '-c:v', 'libx264',
+        '-pix_fmt', 'yuv420p',
+        '-crf', '18',  # High quality
+        output_path
+    ]
+
+    try:
+        subprocess.run(cmd, check=True, capture_output=True)
+
+        # Clean up frame files if requested
+        if cleanup:
+            import glob
+            for frame_file in glob.glob(os.path.join(frame_dir, 'frame_*.png')):
+                os.remove(frame_file)
+
+            # Remove directory if empty
+            try:
+                os.rmdir(frame_dir)
+            except OSError:
+                pass  # Directory not empty
+
+        return output_path
+
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"FFmpeg failed: {e.stderr.decode()}")
+
+
+def create_gif(da: xr.DataArray, dim: str, output_path: str = "animation.gif",
+               duration: int = 200, plot_func: Callable = None, **plot_kwargs) -> str:
+    """
+    Create an animated GIF.
+
+    Args:
+        da: Input DataArray
+        dim: Dimension to animate over
+        output_path: Output GIF file path
+        duration: Duration per frame in milliseconds
+        plot_func: Plotting function to use
+        **plot_kwargs: Additional plotting parameters
+
+    Returns:
+        Path to created GIF file
+    """
+    try:
+        from PIL import Image
+    except ImportError:
+        raise ImportError("Pillow is required for GIF creation")
+
+    # Create frame sequence
+    with tempfile.TemporaryDirectory() as temp_dir:
+        frame_paths = create_frame_sequence(da, dim, plot_func, temp_dir, **plot_kwargs)
+
+        # Load frames and create GIF
+        images = []
+        for frame_path in frame_paths:
+            img = Image.open(frame_path)
+            images.append(img)
+
+        # Save as GIF (inside the with-block, before the temp frames are deleted)
+        images[0].save(output_path, save_all=True, append_images=images[1:],
+                       duration=duration, loop=0)
+
+    return output_path
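`check_ffmpeg` above follows a common capability-probe pattern: try to run the tool once and map any failure to `False`. The same probe generalized to an arbitrary CLI (a hedged sketch; `have_tool` is illustrative and not part of this module, and the result depends on what is installed on the host):

```python
import subprocess

def have_tool(name, probe_arg="-version"):
    """Return True if `name` can be executed successfully, False if missing or failing."""
    try:
        subprocess.run([name, probe_arg], capture_output=True, check=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError, OSError):
        return False

ffmpeg_ok = have_tool("ffmpeg")  # False on hosts without FFmpeg installed
```

Probing once up front (rather than letting `anim.save` fail mid-render) lets the caller raise a clear error message before any frames are drawn.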
tensorview/colors.py ADDED
@@ -0,0 +1,273 @@
+"""Colormap loading and management for CPT/ACT/RGB files."""
+
+import os
+import re
+from typing import List, Tuple, Optional
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.colors as mcolors
+from matplotlib.colors import LinearSegmentedColormap, ListedColormap
+
+
+def parse_cpt_file(filepath: str) -> LinearSegmentedColormap:
+    """
+    Load a GMT-style CPT (Color Palette Table) file.
+
+    Args:
+        filepath: Path to the CPT file
+
+    Returns:
+        matplotlib LinearSegmentedColormap
+    """
+    colors = []
+    positions = []
+
+    with open(filepath, 'r') as f:
+        lines = f.readlines()
+
+    # Parse CPT format
+    for line in lines:
+        line = line.strip()
+
+        # Skip comments, empty lines, and B/F/N (back/fore/NaN) rows
+        if line.startswith('#') or not line or line.startswith('B') or line.startswith('F') or line.startswith('N'):
+            continue
+
+        parts = line.split()
+        if len(parts) >= 8:
+            # CPT format: z0 r0 g0 b0 z1 r1 g1 b1
+            try:
+                z0, r0, g0, b0, z1, r1, g1, b1 = map(float, parts[:8])
+
+                # Normalize RGB values if they're in 0-255 range
+                if r0 > 1 or g0 > 1 or b0 > 1:
+                    r0, g0, b0 = r0/255, g0/255, b0/255
+                if r1 > 1 or g1 > 1 or b1 > 1:
+                    r1, g1, b1 = r1/255, g1/255, b1/255
+
+                colors.extend([(r0, g0, b0), (r1, g1, b1)])
+                positions.extend([z0, z1])
+
+            except ValueError:
+                continue
+
+    if not colors:
+        raise ValueError(f"No valid color data found in {filepath}")
+
+    # Normalize positions to 0-1 range
+    positions = np.array(positions)
+    if len(set(positions)) > 1:
+        positions = (positions - positions.min()) / (positions.max() - positions.min())
+    else:
+        positions = np.linspace(0, 1, len(positions))
+
+    # Create colormap
+    cmap_data = list(zip(positions, colors))
+    cmap_data.sort(key=lambda x: x[0])  # Sort by position
+
+    # Remove duplicates
+    unique_data = []
+    seen_positions = set()
+    for pos, color in cmap_data:
+        if pos not in seen_positions:
+            unique_data.append((pos, color))
+            seen_positions.add(pos)
+
+    positions, colors = zip(*unique_data)
+
+    # Create the colormap
+    name = os.path.splitext(os.path.basename(filepath))[0]
+    return LinearSegmentedColormap.from_list(name, list(zip(positions, colors)))
+
+
+def parse_act_file(filepath: str) -> ListedColormap:
+    """
+    Load an Adobe Color Table (ACT) file.
+
+    Args:
+        filepath: Path to the ACT file
+
+    Returns:
+        matplotlib ListedColormap
+    """
+    with open(filepath, 'rb') as f:
+        data = f.read()
+
+    # ACT files are 768 bytes (256 colors * 3 RGB values)
+    if len(data) != 768:
+        raise ValueError(f"Invalid ACT file size: {len(data)} bytes (expected 768)")
+
+    colors = []
+    for i in range(0, 768, 3):
+        r, g, b = data[i], data[i+1], data[i+2]
+        colors.append((r/255.0, g/255.0, b/255.0))
+
+    name = os.path.splitext(os.path.basename(filepath))[0]
+    return ListedColormap(colors, name=name)
+
+
+def parse_rgb_file(filepath: str) -> ListedColormap:
+    """
+    Load a simple RGB text file (one color per line).
+
+    Args:
+        filepath: Path to the RGB file
+
+    Returns:
+        matplotlib ListedColormap
+    """
+    colors = []
+
+    with open(filepath, 'r') as f:
+        lines = f.readlines()
+
+    for line in lines:
+        line = line.strip()
+
+        # Skip comments and empty lines
+        if line.startswith('#') or not line:
+            continue
+
+        # Parse RGB values
+        parts = line.split()
+        if len(parts) >= 3:
+            try:
+                r, g, b = map(float, parts[:3])
+
+                # Normalize RGB values if they're in 0-255 range
+                if r > 1 or g > 1 or b > 1:
+                    r, g, b = r/255, g/255, b/255
+
+                colors.append((r, g, b))
+            except ValueError:
+                continue
+
+    if not colors:
+        raise ValueError(f"No valid color data found in {filepath}")
+
+    name = os.path.splitext(os.path.basename(filepath))[0]
+    return ListedColormap(colors, name=name)
+
+
+def load_cpt(filepath: str) -> LinearSegmentedColormap:
+    """Load a CPT colormap file."""
+    return parse_cpt_file(filepath)
+
+
+def load_act(filepath: str) -> ListedColormap:
+    """Load an ACT colormap file."""
+    return parse_act_file(filepath)
+
+
+def load_rgb(filepath: str) -> ListedColormap:
+    """Load an RGB colormap file."""
+    return parse_rgb_file(filepath)
+
+
+def load_colormap(filepath: str) -> mcolors.Colormap:
+    """
+    Load a colormap from file, auto-detecting the format.
+
+    Args:
+        filepath: Path to the colormap file
+
+    Returns:
+        matplotlib Colormap
+    """
+    _, ext = os.path.splitext(filepath.lower())
+
+    if ext == '.cpt':
+        return load_cpt(filepath)
+    elif ext == '.act':
+        return load_act(filepath)
+    elif ext in ['.rgb', '.txt']:
+        return load_rgb(filepath)
+    else:
+        # Try to guess based on content
+        try:
+            return load_cpt(filepath)
+        except Exception:
+            try:
+                return load_rgb(filepath)
+            except Exception:
+                raise ValueError(f"Cannot determine colormap format for {filepath}")
+
+
+def get_matplotlib_colormaps() -> List[str]:
+    """Get list of available matplotlib colormaps."""
+    return sorted(plt.colormaps())
+
+
+def create_diverging_colormap(name: str, colors: List[str], center: float = 0.5) -> LinearSegmentedColormap:
+    """
+    Create a diverging colormap.
+
+    Args:
+        name: Name for the colormap
+        colors: List of color names/hex codes
+        center: Position of the center color (0-1)
+
+    Returns:
+        LinearSegmentedColormap
+    """
+    return LinearSegmentedColormap.from_list(name, colors)
+
+
+def reverse_colormap(cmap: mcolors.Colormap) -> mcolors.Colormap:
+    """Reverse a colormap."""
+    if hasattr(cmap, 'reversed'):
+        return cmap.reversed()
+    else:
+        # For custom colormaps
+        if isinstance(cmap, LinearSegmentedColormap):
+            return LinearSegmentedColormap.from_list(
+                f"{cmap.name}_r",
+                cmap(np.linspace(0, 1, 256))[::-1]
+            )
+        elif isinstance(cmap, ListedColormap):
+            return ListedColormap(
+                cmap.colors[::-1],
+                name=f"{cmap.name}_r"
+            )
+        else:
+            return cmap
+
+
+def colormap_to_cpt(cmap: mcolors.Colormap, filepath: str, n_colors: int = 256):
+    """
+    Save a matplotlib colormap as a CPT file.
+
+    Args:
+        cmap: matplotlib Colormap
+        filepath: Output file path
+        n_colors: Number of color levels
+    """
+    colors = cmap(np.linspace(0, 1, n_colors))
+
+    with open(filepath, 'w') as f:
+        f.write(f"# Color palette: {cmap.name}\n")
+        f.write("# Created by TensorView\n")
+
+        for i in range(len(colors) - 1):
+            r0, g0, b0 = colors[i][:3]
+            r1, g1, b1 = colors[i+1][:3]
+
+            z0 = i / (n_colors - 1)
+            z1 = (i + 1) / (n_colors - 1)
+
+            f.write(f"{z0:.6f}\t{r0*255:.0f}\t{g0*255:.0f}\t{b0*255:.0f}\t")
+            f.write(f"{z1:.6f}\t{r1*255:.0f}\t{g1*255:.0f}\t{b1*255:.0f}\n")
+
+
+def get_colormap_info(cmap: mcolors.Colormap) -> dict:
+    """Get information about a colormap."""
+    info = {
+        'name': getattr(cmap, 'name', 'unknown'),
+        'type': type(cmap).__name__,
+        'n_colors': getattr(cmap, 'N', 'continuous')
+    }
+
+    if isinstance(cmap, ListedColormap):
+        info['n_colors'] = len(cmap.colors)
+
+    return info
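The fixed 768-byte layout handled by `parse_act_file` can be exercised without touching the filesystem. A self-contained sketch on an in-memory table (`act_bytes_to_colors` is an illustrative stand-in, not the shipped function):

```python
def act_bytes_to_colors(data: bytes):
    """Decode 256 RGB triplets (768 bytes) into 0-1 float tuples."""
    if len(data) != 768:
        raise ValueError(f"Invalid ACT size: {len(data)} bytes (expected 768)")
    return [(data[i] / 255.0, data[i + 1] / 255.0, data[i + 2] / 255.0)
            for i in range(0, 768, 3)]

# Synthetic table: a grayscale ramp from black (0,0,0) to white (255,255,255)
table = bytes(v for v in range(256) for _ in range(3))
colors = act_bytes_to_colors(table)
```

The strict 768-byte check matters because some ACT exports append a 4-byte footer (color count plus transparency index); those files would need the footer stripped before decoding.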
tensorview/grid.py ADDED
@@ -0,0 +1,227 @@
+"""Grid alignment and combination operations."""
+
+from typing import Literal, Dict, Any, Tuple
+import numpy as np
+import xarray as xr
+from .utils import identify_coordinates, get_crs, is_geographic
+
+
+def align_for_combine(a: xr.DataArray, b: xr.DataArray, method: str = "reindex") -> Tuple[xr.DataArray, xr.DataArray]:
+    """
+    Align two DataArrays for combination operations.
+
+    Args:
+        a, b: Input DataArrays
+        method: Alignment method ('reindex', 'interp')
+
+    Returns:
+        Tuple of aligned DataArrays
+    """
+    # Check CRS compatibility
+    crs_a = get_crs(a)
+    crs_b = get_crs(b)
+
+    if crs_a and crs_b and not crs_a.equals(crs_b):
+        raise ValueError(f"CRS mismatch: {crs_a} vs {crs_b}")
+
+    # Get coordinate information
+    coords_a = identify_coordinates(a)
+    coords_b = identify_coordinates(b)
+
+    # Find common dimensions
+    common_dims = set(a.dims) & set(b.dims)
+
+    if not common_dims:
+        raise ValueError("No common dimensions found for alignment")
+
+    # Align coordinates
+    if method == "reindex":
+        # Use nearest neighbor reindexing
+        a_aligned = a
+        b_aligned = b
+
+        for dim in common_dims:
+            if dim in a.dims and dim in b.dims:
+                # Get the coordinates for this dimension
+                coord_a = a.coords[dim]
+                coord_b = b.coords[dim]
+
+                # Use the coordinate with higher resolution
+                if len(coord_a) >= len(coord_b):
+                    target_coord = coord_a
+                else:
+                    target_coord = coord_b
+
+                # Reindex both arrays to the target coordinate
+                a_aligned = a_aligned.reindex({dim: target_coord}, method='nearest')
+                b_aligned = b_aligned.reindex({dim: target_coord}, method='nearest')
+
+    elif method == "interp":
+        # Use interpolation
+        # Find common coordinate grid
+        common_coords = {}
+        for dim in common_dims:
+            if dim in a.dims and dim in b.dims:
+                coord_a = a.coords[dim]
+                coord_b = b.coords[dim]
+
+                # Create a common grid (intersection)
+                min_val = max(float(coord_a.min()), float(coord_b.min()))
+                max_val = min(float(coord_a.max()), float(coord_b.max()))
+
+                # Use the finer resolution
+                res_a = float(coord_a[1] - coord_a[0]) if len(coord_a) > 1 else 1.0
+                res_b = float(coord_b[1] - coord_b[0]) if len(coord_b) > 1 else 1.0
+                res = min(abs(res_a), abs(res_b))
+
+                common_coords[dim] = np.arange(min_val, max_val + res, res)
+
+        a_aligned = a.interp(common_coords)
+        b_aligned = b.interp(common_coords)
+
+    else:
+        raise ValueError(f"Unknown alignment method: {method}")
+
+    return a_aligned, b_aligned
+
+
+def combine(a: xr.DataArray, b: xr.DataArray, op: Literal["sum", "avg", "diff"] = "sum") -> xr.DataArray:
+    """
+    Combine two DataArrays with the specified operation.
+
+    Args:
+        a, b: Input DataArrays
+        op: Operation ('sum', 'avg', 'diff')
+
+    Returns:
+        Combined DataArray
+    """
+    # Align the arrays first
+    a_aligned, b_aligned = align_for_combine(a, b)
+
+    # Perform the operation
+    if op == "sum":
+        result = a_aligned + b_aligned
+    elif op == "avg":
+        result = (a_aligned + b_aligned) / 2
+    elif op == "diff":
+        result = a_aligned - b_aligned
+    else:
+        raise ValueError(f"Unknown operation: {op}")
+
+    # Update attributes
+    result.name = f"{a.name}_{op}_{b.name}"
+
+    if op == "sum":
+        result.attrs['long_name'] = f"{a.attrs.get('long_name', a.name)} + {b.attrs.get('long_name', b.name)}"
+    elif op == "avg":
+        result.attrs['long_name'] = f"Average of {a.attrs.get('long_name', a.name)} and {b.attrs.get('long_name', b.name)}"
+    elif op == "diff":
+        result.attrs['long_name'] = f"{a.attrs.get('long_name', a.name)} - {b.attrs.get('long_name', b.name)}"
+
+    # Preserve units if they match
+    if a.attrs.get('units') == b.attrs.get('units'):
+        result.attrs['units'] = a.attrs.get('units', '')
+
+    return result
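The "reindex to the denser grid with nearest neighbours" step in `align_for_combine` can be illustrated without xarray. A pure-Python sketch on 1-D coordinates (`reindex_nearest` is a stand-in for `DataArray.reindex(..., method='nearest')`):

```python
def reindex_nearest(coords, values, target):
    """For each target coordinate, take the value at the nearest source coordinate."""
    out = []
    for t in target:
        nearest = min(range(len(coords)), key=lambda i: abs(coords[i] - t))
        out.append(values[nearest])
    return out

# A coarse 3-point grid resampled onto a denser 4-point grid
coarse = reindex_nearest([0.0, 1.0, 2.0], [10, 20, 30], [0.0, 0.4, 0.6, 2.0])
```

Nearest-neighbour reindexing never invents intermediate values (every output is one of the inputs), which is why `combine` uses it as the default: sums and differences stay physically meaningful even on mismatched grids.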
+
+
+def section(da: xr.DataArray, along: str, fixed: Dict[str, Any]) -> xr.DataArray:
+    """
+    Create a cross-section of the DataArray.
+
+    Args:
+        da: Input DataArray
+        along: Dimension to keep for the section (e.g., 'time', 'lat')
+        fixed: Dictionary of {dim: value} for dimensions to fix
+
+    Returns:
+        Cross-section DataArray
+    """
+    if along not in da.dims:
+        raise ValueError(f"Dimension '{along}' not found in DataArray")
+
+    # Start with the full array
+    result = da
+
+    # Apply fixed selections
+    selection = {}
+    for dim, value in fixed.items():
+        if dim not in da.dims:
+            continue
+
+        coord = da.coords[dim]
+
+        if isinstance(value, (int, float)):
+            # Select nearest value
+            selection[dim] = coord.sel({dim: value}, method='nearest')
+        elif isinstance(value, str) and 'time' in dim.lower():
+            # Handle time strings
+            selection[dim] = value
+        else:
+            selection[dim] = value
+
+    if selection:
+        result = result.sel(selection, method='nearest')
+
+    # Ensure the 'along' dimension is preserved
+    if along not in result.dims:
+        raise ValueError(f"Section operation removed the '{along}' dimension")
+
+    # Update metadata
+    result.attrs = da.attrs.copy()
+
+    # Add section info to long_name
+    section_info = []
+    for dim, value in fixed.items():
+        if dim in da.dims:
+            if isinstance(value, (int, float)):
+                section_info.append(f"{dim}={value:.3f}")
+            else:
+                section_info.append(f"{dim}={value}")
+
+    if section_info:
+        long_name = result.attrs.get('long_name', result.name)
+        result.attrs['long_name'] = f"{long_name} ({', '.join(section_info)})"
+
+    return result
+
+
+def aggregate_spatial(da: xr.DataArray, method: str = "mean") -> xr.DataArray:
+    """
+    Aggregate spatially (e.g., zonal mean).
+
+    Args:
+        da: Input DataArray
+        method: Aggregation method ('mean', 'sum', 'std')
+
+    Returns:
+        Spatially aggregated DataArray
+    """
+    coords = identify_coordinates(da)
+
+    spatial_dims = []
+    if 'X' in coords:
+        spatial_dims.append(coords['X'])
206
+ if 'Y' in coords:
207
+ spatial_dims.append(coords['Y'])
208
+
209
+ if not spatial_dims:
210
+ raise ValueError("No spatial dimensions found for aggregation")
211
+
212
+ # Perform aggregation
213
+ if method == "mean":
214
+ result = da.mean(dim=spatial_dims)
215
+ elif method == "sum":
216
+ result = da.sum(dim=spatial_dims)
217
+ elif method == "std":
218
+ result = da.std(dim=spatial_dims)
219
+ else:
220
+ raise ValueError(f"Unknown aggregation method: {method}")
221
+
222
+ # Update attributes
223
+ result.attrs = da.attrs.copy()
224
+ long_name = result.attrs.get('long_name', result.name)
225
+ result.attrs['long_name'] = f"{method.capitalize()} of {long_name}"
226
+
227
+ return result
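The `op` dispatch in `combine()` above reduces to simple arithmetic once the arrays are aligned; a minimal standalone sketch of the same dispatch on plain numbers (illustrative only, not part of the module):

```python
def combine_values(a, b, op="sum"):
    # Mirrors the op dispatch in combine(), on plain numbers
    if op == "sum":
        return a + b
    elif op == "avg":
        return (a + b) / 2
    elif op == "diff":
        return a - b
    raise ValueError(f"Unknown operation: {op}")

print(combine_values(2, 4, "avg"))  # → 3.0
```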
tensorview/io.py ADDED
@@ -0,0 +1,162 @@
+ """I/O operations for opening and managing datasets."""
+
+ import os
+ from typing import Optional, Dict, Any, List
+ from dataclasses import dataclass
+ import xarray as xr
+ import fsspec
+ import warnings
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+
+ @dataclass
+ class VariableSpec:
+     """Variable specification with metadata."""
+     name: str
+     shape: tuple
+     dims: tuple
+     dtype: str
+     units: str
+     long_name: str
+     attrs: Dict[str, Any]
+
+
+ class DatasetHandle:
+     """Handle for opened datasets."""
+
+     def __init__(self, dataset: xr.Dataset, uri: str, engine: str):
+         self.dataset = dataset
+         self.uri = uri
+         self.engine = engine
+
+     def close(self):
+         """Close the dataset."""
+         if hasattr(self.dataset, 'close'):
+             self.dataset.close()
+
+
+ def detect_engine(uri: str) -> str:
+     """Auto-detect the appropriate engine for a given URI/file."""
+     if uri.lower().endswith('.zarr') or 'zarr' in uri.lower():
+         return 'zarr'
+     elif uri.lower().endswith('.grib') or uri.lower().endswith('.grb'):
+         return 'cfgrib'
+     elif any(ext in uri.lower() for ext in ['.nc', '.netcdf', '.hdf', '.h5']):
+         # Try h5netcdf first, fall back to netcdf4
+         try:
+             import h5netcdf
+             return 'h5netcdf'
+         except ImportError:
+             return 'netcdf4'
+     else:
+         # Default fallback
+         try:
+             import h5netcdf
+             return 'h5netcdf'
+         except ImportError:
+             return 'netcdf4'
+
+
+ def open_any(uri: str, engine: Optional[str] = None, chunks: str = "auto") -> DatasetHandle:
+     """
+     Open a dataset from various sources (local, HTTP, S3, etc.).
+
+     Args:
+         uri: Path or URL to dataset
+         engine: Engine to use ('h5netcdf', 'netcdf4', 'cfgrib', 'zarr')
+         chunks: Chunking strategy for dask
+
+     Returns:
+         DatasetHandle: Handle to the opened dataset
+     """
+     if engine is None:
+         engine = detect_engine(uri)
+
+     try:
+         if engine == 'zarr':
+             # For Zarr stores
+             if uri.startswith('s3://'):
+                 import s3fs
+                 fs = s3fs.S3FileSystem(anon=True)
+                 store = s3fs.S3Map(root=uri, s3=fs, check=False)
+                 ds = xr.open_zarr(store, chunks=chunks)
+             else:
+                 ds = xr.open_zarr(uri, chunks=chunks)
+         elif engine == 'cfgrib':
+             # For GRIB files
+             ds = xr.open_dataset(uri, engine='cfgrib', chunks=chunks)
+         else:
+             # For netCDF/HDF files
+             if uri.startswith(('http://', 'https://')):
+                 # Remote files (including OPeNDAP)
+                 ds = xr.open_dataset(uri, engine=engine, chunks=chunks)
+             elif uri.startswith('s3://'):
+                 # S3 files
+                 import s3fs
+                 fs = s3fs.S3FileSystem(anon=True)
+                 with fs.open(uri, 'rb') as f:
+                     ds = xr.open_dataset(f, engine=engine, chunks=chunks)
+             else:
+                 # Local files
+                 ds = xr.open_dataset(uri, engine=engine, chunks=chunks)
+
+         return DatasetHandle(ds, uri, engine)
+
+     except Exception as e:
+         raise RuntimeError(f"Failed to open {uri} with engine {engine}: {str(e)}")
+
+
+ def list_variables(handle: DatasetHandle) -> List[VariableSpec]:
+     """
+     List all data variables in the dataset with their specifications.
+
+     Args:
+         handle: Dataset handle
+
+     Returns:
+         List of VariableSpec objects
+     """
+     variables = []
+
+     for var_name, var in handle.dataset.data_vars.items():
+         # Skip coordinate variables and bounds
+         if var_name.endswith('_bounds') or var_name in handle.dataset.coords:
+             continue
+
+         attrs = dict(var.attrs)
+
+         spec = VariableSpec(
+             name=var_name,
+             shape=var.shape,
+             dims=var.dims,
+             dtype=str(var.dtype),
+             units=attrs.get('units', ''),
+             long_name=attrs.get('long_name', var_name),
+             attrs=attrs
+         )
+         variables.append(spec)
+
+     return variables
+
+
+ def get_dataarray(handle: DatasetHandle, var: str) -> xr.DataArray:
+     """
+     Get a specific data array from the dataset.
+
+     Args:
+         handle: Dataset handle
+         var: Variable name
+
+     Returns:
+         xarray DataArray
+     """
+     if var not in handle.dataset.data_vars:
+         raise ValueError(f"Variable '{var}' not found in dataset")
+
+     return handle.dataset[var]
+
+
+ def close(handle: DatasetHandle):
+     """Close a dataset handle."""
+     handle.close()
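The engine detection in `open_any()` is a suffix heuristic; a condensed sketch of the same dispatch, assuming h5netcdf is installed so the ImportError fallback to netcdf4 never triggers:

```python
def guess_engine(uri):
    # Simplified version of detect_engine(): suffix-based dispatch only,
    # assuming the h5netcdf backend is available (no ImportError probe).
    u = uri.lower()
    if u.endswith(".zarr") or "zarr" in u:
        return "zarr"
    if u.endswith((".grib", ".grb")):
        return "cfgrib"
    return "h5netcdf"  # netCDF/HDF files and the default fallback

print(guess_engine("s3://bucket/era5.zarr"))  # → zarr
```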
tensorview/plot.py ADDED
@@ -0,0 +1,354 @@
+ """Plotting functions for 1D, 2D, and map visualizations."""
+
+ import io
+ import os
+ from typing import Optional, Dict, Any, Tuple, Literal, Union
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.colors as mcolors
+ from matplotlib.figure import Figure
+ from matplotlib.axes import Axes
+ import xarray as xr
+
+ try:
+     import cartopy.crs as ccrs
+     import cartopy.feature as cfeature
+     HAS_CARTOPY = True
+ except ImportError:
+     HAS_CARTOPY = False
+
+ from .utils import identify_coordinates, get_crs, is_geographic, format_value
+
+
+ def setup_matplotlib():
+     """Set up matplotlib with a non-interactive backend."""
+     plt.switch_backend('Agg')
+     plt.style.use('default')
+
+
+ def plot_1d(da: xr.DataArray, x_dim: Optional[str] = None, **style) -> Figure:
+     """
+     Create a 1D line plot.
+
+     Args:
+         da: Input DataArray (should be 1D or have only one varying dimension)
+         x_dim: Dimension to use as x-axis (auto-detected if None)
+         **style: Style parameters (color, linewidth, etc.)
+
+     Returns:
+         matplotlib Figure
+     """
+     setup_matplotlib()
+
+     # Find the appropriate dimension for the x-axis
+     if x_dim is None:
+         # Use the first dimension with more than one element
+         for dim in da.dims:
+             if da.sizes[dim] > 1:
+                 x_dim = dim
+                 break
+
+     if x_dim is None:
+         raise ValueError("No suitable dimension found for 1D plot")
+
+     if x_dim not in da.dims:
+         raise ValueError(f"Dimension '{x_dim}' not found in DataArray")
+
+     # Create the figure
+     fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Get data for plotting
+     x_data = da.coords[x_dim]
+     y_data = da
+
+     # Plot the data
+     line_style = {
+         'color': style.get('color', 'blue'),
+         'linewidth': style.get('linewidth', 1.5),
+         'linestyle': style.get('linestyle', '-'),
+         'marker': style.get('marker', ''),
+         'markersize': style.get('markersize', 4),
+         'alpha': style.get('alpha', 1.0)
+     }
+
+     ax.plot(x_data, y_data, **line_style)
+
+     # Set labels
+     ax.set_xlabel(f"{x_dim} ({x_data.attrs.get('units', '')})")
+     ax.set_ylabel(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
+
+     # Set title
+     title = da.attrs.get('long_name', da.name or 'Data')
+     ax.set_title(title)
+
+     # Add grid if requested
+     if style.get('grid', True):
+         ax.grid(True, alpha=0.3)
+
+     # Handle time axis formatting
+     if 'time' in x_dim.lower() or x_data.dtype.kind == 'M':
+         fig.autofmt_xdate()
+
+     plt.tight_layout()
+     return fig
+
+
+ def plot_2d(da: xr.DataArray, kind: Literal["image", "contour"] = "image",
+             x_dim: Optional[str] = None, y_dim: Optional[str] = None, **style) -> Figure:
+     """
+     Create a 2D plot (image or contour).
+
+     Args:
+         da: Input DataArray (should be 2D)
+         kind: Plot type ('image' or 'contour')
+         x_dim, y_dim: Dimensions to use for axes
+         **style: Style parameters
+
+     Returns:
+         matplotlib Figure
+     """
+     setup_matplotlib()
+
+     # Auto-detect dimensions if not provided
+     if x_dim is None or y_dim is None:
+         coords = identify_coordinates(da)
+         if x_dim is None:
+             x_dim = coords.get('X', da.dims[-1])  # Default to last dimension
+         if y_dim is None:
+             y_dim = coords.get('Y', da.dims[-2])  # Default to second-to-last dimension
+
+     if x_dim not in da.dims or y_dim not in da.dims:
+         raise ValueError(f"Dimensions {x_dim}, {y_dim} not found in DataArray")
+
+     # Transpose to get (y, x) order for plotting
+     da_plot = da.transpose(y_dim, x_dim)
+
+     # Create figure
+     fig, ax = plt.subplots(figsize=(10, 8))
+
+     # Get coordinates
+     x_coord = da.coords[x_dim]
+     y_coord = da.coords[y_dim]
+
+     # Set up colormap
+     cmap = style.get('cmap', 'viridis')
+     if isinstance(cmap, str):
+         cmap = plt.get_cmap(cmap)
+
+     # Set up normalization
+     vmin = style.get('vmin', float(da.min().values))
+     vmax = style.get('vmax', float(da.max().values))
+     norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
+
+     if kind == "image":
+         # Use imshow for regular grids
+         im = ax.imshow(da_plot.values,
+                        extent=[float(x_coord.min()), float(x_coord.max()),
+                                float(y_coord.min()), float(y_coord.max())],
+                        aspect='auto', origin='lower', cmap=cmap, norm=norm)
+
+     elif kind == "contour":
+         # Use contourf for contour plots
+         levels = style.get('levels', 20)
+         if isinstance(levels, int):
+             levels = np.linspace(vmin, vmax, levels)
+
+         X, Y = np.meshgrid(x_coord, y_coord)
+         im = ax.contourf(X, Y, da_plot.values, levels=levels, cmap=cmap, norm=norm)
+
+         # Add contour lines if requested
+         if style.get('contour_lines', False):
+             cs = ax.contour(X, Y, da_plot.values, levels=levels, colors='k', linewidths=0.5)
+             ax.clabel(cs, inline=True, fontsize=8)
+
+     # Add colorbar
+     if style.get('colorbar', True):
+         cbar = plt.colorbar(im, ax=ax)
+         cbar.set_label(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
+
+     # Set labels
+     ax.set_xlabel(f"{x_dim} ({x_coord.attrs.get('units', '')})")
+     ax.set_ylabel(f"{y_dim} ({y_coord.attrs.get('units', '')})")
+
+     # Set title
+     title = da.attrs.get('long_name', da.name or 'Data')
+     ax.set_title(title)
+
+     plt.tight_layout()
+     return fig
+
+
+ def plot_map(da: xr.DataArray, proj: str = "PlateCarree", **style) -> Figure:
+     """
+     Create a map plot with cartopy.
+
+     Args:
+         da: Input DataArray with geographic coordinates
+         proj: Map projection name
+         **style: Style parameters
+
+     Returns:
+         matplotlib Figure
+     """
+     if not HAS_CARTOPY:
+         raise ImportError("Cartopy is required for map plotting")
+
+     setup_matplotlib()
+
+     # Check if data is geographic
+     if not is_geographic(da):
+         raise ValueError("DataArray does not appear to have geographic coordinates")
+
+     # Get coordinate information
+     coords = identify_coordinates(da)
+     if 'X' not in coords or 'Y' not in coords:
+         raise ValueError("Could not identify longitude/latitude coordinates")
+
+     lon_dim = coords['X']
+     lat_dim = coords['Y']
+
+     # Set up projection
+     proj_map = {
+         'PlateCarree': ccrs.PlateCarree(),
+         'Robinson': ccrs.Robinson(),
+         'Mollweide': ccrs.Mollweide(),
+         'Orthographic': ccrs.Orthographic(),
+         'NorthPolarStereo': ccrs.NorthPolarStereo(),
+         'SouthPolarStereo': ccrs.SouthPolarStereo(),
+         'Miller': ccrs.Miller(),
+         'InterruptedGoodeHomolosine': ccrs.InterruptedGoodeHomolosine()
+     }
+
+     if proj not in proj_map:
+         proj = 'PlateCarree'  # Default fallback
+
+     projection = proj_map[proj]
+
+     # Create figure with cartopy
+     fig, ax = plt.subplots(figsize=(12, 8),
+                            subplot_kw={'projection': projection})
+
+     # Transpose to get (lat, lon) order
+     da_plot = da.transpose(lat_dim, lon_dim)
+
+     # Get coordinates
+     lons = da.coords[lon_dim].values
+     lats = da.coords[lat_dim].values
+
+     # Set up colormap and normalization
+     cmap = style.get('cmap', 'viridis')
+     if isinstance(cmap, str):
+         cmap = plt.get_cmap(cmap)
+
+     vmin = style.get('vmin', float(da.min().values))
+     vmax = style.get('vmax', float(da.max().values))
+
+     # Create plot
+     plot_type = style.get('plot_type', 'pcolormesh')
+
+     if plot_type == 'contourf':
+         levels = style.get('levels', 20)
+         if isinstance(levels, int):
+             levels = np.linspace(vmin, vmax, levels)
+         im = ax.contourf(lons, lats, da_plot.values, levels=levels,
+                          cmap=cmap, transform=ccrs.PlateCarree())
+     else:
+         im = ax.pcolormesh(lons, lats, da_plot.values, cmap=cmap,
+                            transform=ccrs.PlateCarree(),
+                            vmin=vmin, vmax=vmax, shading='auto')
+
+     # Add map features
+     if style.get('coastlines', True):
+         ax.coastlines(resolution='50m', color='black', linewidth=0.5)
+
+     if style.get('borders', False):
+         ax.add_feature(cfeature.BORDERS, linewidth=0.5)
+
+     if style.get('ocean', False):
+         ax.add_feature(cfeature.OCEAN, color='lightblue', alpha=0.5)
+
+     if style.get('land', False):
+         ax.add_feature(cfeature.LAND, color='lightgray', alpha=0.5)
+
+     # Add gridlines
+     if style.get('gridlines', True):
+         gl = ax.gridlines(draw_labels=True, alpha=0.5)
+         gl.top_labels = False
+         gl.right_labels = False
+
+     # Set extent if specified
+     if 'extent' in style:
+         ax.set_extent(style['extent'], crs=ccrs.PlateCarree())
+     else:
+         ax.set_global()
+
+     # Add colorbar
+     if style.get('colorbar', True):
+         cbar = plt.colorbar(im, ax=ax, orientation='horizontal',
+                             pad=0.05, shrink=0.8)
+         cbar.set_label(f"{da.name or 'Value'} ({da.attrs.get('units', '')})")
+
+     # Set title
+     title = da.attrs.get('long_name', da.name or 'Data')
+     ax.set_title(title, pad=20)
+
+     plt.tight_layout()
+     return fig
+
+
+ def export_fig(fig: Figure, fmt: Literal["png", "svg", "pdf"] = "png",
+                dpi: int = 150, out_path: Optional[str] = None) -> Union[str, bytes]:
+     """
+     Export a figure to a file or return the raw image bytes.
+
+     Args:
+         fig: matplotlib Figure
+         fmt: Output format
+         dpi: Resolution for raster formats
+         out_path: Output file path (if None, returns bytes)
+
+     Returns:
+         File path (if out_path is given) or image bytes
+     """
+     if out_path is None:
+         # Return as bytes
+         buf = io.BytesIO()
+         fig.savefig(buf, format=fmt, dpi=dpi, bbox_inches='tight')
+         buf.seek(0)
+         return buf.getvalue()
+     else:
+         # Save to file
+         fig.savefig(out_path, format=fmt, dpi=dpi, bbox_inches='tight')
+         return out_path
+
+
+ def create_subplot_figure(n_plots: int, ncols: int = 2) -> Tuple[Figure, np.ndarray]:
+     """Create a figure with multiple subplots."""
+     nrows = (n_plots + ncols - 1) // ncols
+     fig, axes = plt.subplots(nrows, ncols, figsize=(6*ncols, 4*nrows))
+
+     if n_plots == 1:
+         axes = np.array([axes])
+     elif nrows == 1:
+         axes = axes.reshape(1, -1)
+
+     # Hide unused subplots
+     for i in range(n_plots, nrows * ncols):
+         axes.flat[i].set_visible(False)
+
+     return fig, axes
+
+
+ def add_statistics_text(ax: Axes, da: xr.DataArray, x: float = 0.02, y: float = 0.98):
+     """Add a statistics text box to a plot."""
+     stats = [
+         f"Min: {float(da.min().values):.3g}",
+         f"Max: {float(da.max().values):.3g}",
+         f"Mean: {float(da.mean().values):.3g}",
+         f"Std: {float(da.std().values):.3g}"
+     ]
+
+     text = '\n'.join(stats)
+     ax.text(x, y, text, transform=ax.transAxes,
+             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
+             verticalalignment='top', fontsize=8)
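The commit message mentions automatic percentile-based color scaling, while the plotting functions above take `vmin`/`vmax` through `style`; the following is an assumed sketch of how robust limits could be computed before being passed in (pure Python, matching `numpy.percentile` with linear interpolation for these inputs; `robust_limits` is a hypothetical helper, not part of the diff):

```python
def robust_limits(values, lo=2.0, hi=98.0):
    # Clip the color scale to the [lo, hi] percentile range so a few
    # outliers do not wash out the colormap.
    s = sorted(values)

    def pct(p):
        # Linear-interpolation percentile on the sorted sample
        k = (len(s) - 1) * p / 100.0
        f = int(k)
        c = min(f + 1, len(s) - 1)
        return s[f] + (s[c] - s[f]) * (k - f)

    return pct(lo), pct(hi)

print(robust_limits(range(101)))  # → (2.0, 98.0)
```

The returned pair would be handed to `plot_2d`/`plot_map` as `vmin=...` and `vmax=...`.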
tensorview/state.py ADDED
@@ -0,0 +1,286 @@
+ """View state serialization and deserialization for saving/loading plot configurations."""
+
+ import json
+ from typing import Dict, Any, Optional, List
+ from dataclasses import dataclass, asdict
+ from datetime import datetime
+ import orjson
+ from pydantic import BaseModel, Field, validator
+
+
+ class PlotConfig(BaseModel):
+     """Plot configuration schema."""
+     plot_type: str = Field(..., description="Type of plot (1d, 2d, map)")
+     x_dim: Optional[str] = Field(None, description="X-axis dimension")
+     y_dim: Optional[str] = Field(None, description="Y-axis dimension")
+     projection: str = Field("PlateCarree", description="Map projection")
+     colormap: str = Field("viridis", description="Colormap name")
+     vmin: Optional[float] = Field(None, description="Color scale minimum")
+     vmax: Optional[float] = Field(None, description="Color scale maximum")
+     levels: Optional[List[float]] = Field(None, description="Contour levels")
+     style: Dict[str, Any] = Field(default_factory=dict, description="Additional style parameters")
+
+     @validator('plot_type')
+     def validate_plot_type(cls, v):
+         allowed = ['1d', '2d', 'map', 'contour']
+         if v not in allowed:
+             raise ValueError(f"plot_type must be one of {allowed}")
+         return v
+
+
+ class DataConfig(BaseModel):
+     """Data configuration schema."""
+     uri: str = Field(..., description="Data source URI")
+     engine: Optional[str] = Field(None, description="Data engine")
+     variable_a: str = Field(..., description="Primary variable")
+     variable_b: Optional[str] = Field(None, description="Secondary variable for operations")
+     operation: str = Field("none", description="Operation between variables")
+     selections: Dict[str, Any] = Field(default_factory=dict, description="Dimension selections")
+
+     @validator('operation')
+     def validate_operation(cls, v):
+         allowed = ['none', 'sum', 'avg', 'diff']
+         if v not in allowed:
+             raise ValueError(f"operation must be one of {allowed}")
+         return v
+
+
+ class ViewState(BaseModel):
+     """Complete view state schema."""
+     version: str = Field("1.0", description="State schema version")
+     created: str = Field(default_factory=lambda: datetime.utcnow().isoformat(), description="Creation timestamp")
+     title: str = Field("", description="User-defined title")
+     description: str = Field("", description="User-defined description")
+     data_config: DataConfig = Field(..., description="Data configuration")
+     plot_config: PlotConfig = Field(..., description="Plot configuration")
+     animation: Optional[Dict[str, Any]] = Field(None, description="Animation settings")
+     exports: List[str] = Field(default_factory=list, description="Export history")
+
+     class Config:
+         extra = "allow"  # Allow additional fields for extensibility
+
+
+ def create_view_state(data_config: Dict[str, Any], plot_config: Dict[str, Any],
+                       title: str = "", description: str = "",
+                       animation: Optional[Dict[str, Any]] = None) -> ViewState:
+     """
+     Create a new view state object.
+
+     Args:
+         data_config: Data configuration dictionary
+         plot_config: Plot configuration dictionary
+         title: Optional title
+         description: Optional description
+         animation: Optional animation settings
+
+     Returns:
+         ViewState object
+     """
+     data_cfg = DataConfig(**data_config)
+     plot_cfg = PlotConfig(**plot_config)
+
+     return ViewState(
+         title=title,
+         description=description,
+         data_config=data_cfg,
+         plot_config=plot_cfg,
+         animation=animation
+     )
+
+
+ def dump_state(state: ViewState) -> str:
+     """
+     Serialize a view state to a JSON string.
+
+     Args:
+         state: ViewState object
+
+     Returns:
+         JSON string
+     """
+     # Use orjson for better performance and datetime handling
+     return orjson.dumps(state.dict(), option=orjson.OPT_INDENT_2).decode('utf-8')
+
+
+ def load_state(state_json: str) -> ViewState:
+     """
+     Deserialize a view state from a JSON string.
+
+     Args:
+         state_json: JSON string
+
+     Returns:
+         ViewState object
+     """
+     try:
+         data = orjson.loads(state_json)
+         return ViewState(**data)
+     except Exception as e:
+         raise ValueError(f"Failed to parse view state: {str(e)}")
+
+
+ def save_state_file(state: ViewState, filepath: str) -> str:
+     """
+     Save view state to a file.
+
+     Args:
+         state: ViewState object
+         filepath: Output file path
+
+     Returns:
+         File path
+     """
+     state_json = dump_state(state)
+
+     with open(filepath, 'w', encoding='utf-8') as f:
+         f.write(state_json)
+
+     return filepath
+
+
+ def load_state_file(filepath: str) -> ViewState:
+     """
+     Load view state from a file.
+
+     Args:
+         filepath: Input file path
+
+     Returns:
+         ViewState object
+     """
+     with open(filepath, 'r', encoding='utf-8') as f:
+         state_json = f.read()
+
+     return load_state(state_json)
+
+
+ def merge_states(base_state: ViewState, updates: Dict[str, Any]) -> ViewState:
+     """
+     Merge updates into a base state.
+
+     Args:
+         base_state: Base ViewState object
+         updates: Dictionary of updates
+
+     Returns:
+         New ViewState object with updates applied
+     """
+     # Convert to dict, apply updates, and create new state
+     state_dict = base_state.dict()
+
+     # Deep merge for nested dictionaries
+     def deep_merge(base_dict, update_dict):
+         for key, value in update_dict.items():
+             if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
+                 deep_merge(base_dict[key], value)
+             else:
+                 base_dict[key] = value
+
+     deep_merge(state_dict, updates)
+     return ViewState(**state_dict)
+
+
+ def validate_state_compatibility(state: ViewState) -> List[str]:
+     """
+     Check state compatibility and return any warnings.
+
+     Args:
+         state: ViewState object
+
+     Returns:
+         List of warning messages
+     """
+     warnings = []
+
+     # Check version compatibility
+     current_version = "1.0"
+     if state.version != current_version:
+         warnings.append(f"State version {state.version} may not be fully compatible with current version {current_version}")
+
+     # Check plot type and dimension compatibility
+     plot_config = state.plot_config
+     data_config = state.data_config
+
+     if plot_config.plot_type == "map":
+         if not plot_config.x_dim or not plot_config.y_dim:
+             warnings.append("Map plots require both x_dim and y_dim to be specified")
+
+     elif plot_config.plot_type == "2d":
+         if not plot_config.x_dim or not plot_config.y_dim:
+             warnings.append("2D plots require both x_dim and y_dim to be specified")
+
+     elif plot_config.plot_type == "1d":
+         if not plot_config.x_dim:
+             warnings.append("1D plots require x_dim to be specified")
+
+     # Check operation compatibility
+     if data_config.operation != "none" and not data_config.variable_b:
+         warnings.append(f"Operation '{data_config.operation}' requires variable_b to be specified")
+
+     return warnings
+
+
+ def create_default_state(uri: str, variable: str) -> ViewState:
+     """
+     Create a default view state for a given data source and variable.
+
+     Args:
+         uri: Data source URI
+         variable: Variable name
+
+     Returns:
+         Default ViewState object
+     """
+     data_config = {
+         'uri': uri,
+         'variable_a': variable,
+         'operation': 'none',
+         'selections': {}
+     }
+
+     plot_config = {
+         'plot_type': '2d',
+         'colormap': 'viridis',
+         'style': {
+             'colorbar': True,
+             'grid': True
+         }
+     }
+
+     return create_view_state(data_config, plot_config, title=f"View of {variable}")
+
+
+ def export_state_summary(state: ViewState) -> Dict[str, Any]:
+     """
+     Create a human-readable summary of a view state.
+
+     Args:
+         state: ViewState object
+
+     Returns:
+         Summary dictionary
+     """
+     summary = {
+         'title': state.title or "Untitled View",
+         'created': state.created,
+         'data_source': state.data_config.uri,
+         'primary_variable': state.data_config.variable_a,
+         'plot_type': state.plot_config.plot_type,
+         'has_secondary_variable': state.data_config.variable_b is not None,
+         'operation': state.data_config.operation,
+         'colormap': state.plot_config.colormap,
+         'has_animation': state.animation is not None,
+         'export_count': len(state.exports)
+     }
+
+     # Add dimension info
+     selections = state.data_config.selections
+     if selections:
+         summary['fixed_dimensions'] = list(selections.keys())
+         summary['selection_count'] = len(selections)
+
+     # Add plot-specific info
+     if state.plot_config.plot_type == "map":
+         summary['projection'] = state.plot_config.projection
+
+     return summary
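`merge_states()` relies on a nested `deep_merge` helper: nested dicts merge key by key, everything else is overwritten. The same recursion as a standalone function, applied to a plain state dict:

```python
def deep_merge(base, updates):
    # Same recursion as the helper inside merge_states(): recurse into
    # matching dict values, overwrite scalars and mismatched types.
    for key, value in updates.items():
        if key in base and isinstance(base[key], dict) and isinstance(value, dict):
            deep_merge(base[key], value)
        else:
            base[key] = value
    return base

state = {"plot_config": {"colormap": "viridis", "style": {"grid": True}}}
deep_merge(state, {"plot_config": {"colormap": "plasma"}})
print(state["plot_config"])  # → {'colormap': 'plasma', 'style': {'grid': True}}
```

Note that `style.grid` survives the merge; a plain `dict.update` would have replaced the whole `plot_config` subtree.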
tensorview/utils.py ADDED
@@ -0,0 +1,218 @@
+ """Utility functions for CF conventions and coordinate system helpers."""
+
+ import re
+ from typing import Dict, List, Optional, Tuple, Any
+ import numpy as np
+ import xarray as xr
+ from pyproj import CRS
+
+
+ def guess_cf_axis(da: xr.DataArray, coord_name: str) -> Optional[str]:
+     """
+     Guess the CF axis type (X, Y, Z, T) for a coordinate.
+
+     Args:
+         da: DataArray containing the coordinate
+         coord_name: Name of the coordinate
+
+     Returns:
+         CF axis type ('X', 'Y', 'Z', 'T') or None
+     """
+     if coord_name not in da.coords:
+         return None
+
+     coord = da.coords[coord_name]
+     attrs = coord.attrs
+     name_lower = coord_name.lower()
+
+     # Check explicit axis attribute
+     if 'axis' in attrs:
+         return attrs['axis'].upper()
+
+     # Check standard_name
+     standard_name = attrs.get('standard_name', '').lower()
+     if standard_name in ['longitude', 'projection_x_coordinate']:
+         return 'X'
+     elif standard_name in ['latitude', 'projection_y_coordinate']:
+         return 'Y'
+     elif standard_name in ['time']:
+         return 'T'
+     elif 'altitude' in standard_name or 'height' in standard_name or standard_name == 'air_pressure':
+         return 'Z'
+
+     # Check coordinate name patterns
+     if any(pattern in name_lower for pattern in ['lon', 'x']):
+         return 'X'
+     elif any(pattern in name_lower for pattern in ['lat', 'y']):
+         return 'Y'
+     elif any(pattern in name_lower for pattern in ['time', 't']):
+         return 'T'
+     elif any(pattern in name_lower for pattern in ['lev', 'level', 'pressure', 'z', 'height', 'alt']):
+         return 'Z'
+
+     # Check units
+     units = attrs.get('units', '').lower()
+     if any(unit in units for unit in ['degree_east', 'degrees_east', 'degree_e']):
+         return 'X'
+     elif any(unit in units for unit in ['degree_north', 'degrees_north', 'degree_n']):
+         return 'Y'
+     elif any(unit in units for unit in ['days since', 'hours since', 'seconds since']):
+         return 'T'
+     elif any(unit in units for unit in ['pa', 'hpa', 'mbar', 'mb', 'm', 'km']):
+         return 'Z'
+
+     return None
+
+
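The precedence in `guess_cf_axis()` (explicit `axis` attribute first, then `standard_name`, then name patterns) can be condensed into a standalone sketch that works on a bare name/attrs pair, without xarray (simplified and illustrative; the real function also checks units):

```python
def guess_axis(name, attrs):
    # Condensed precedence from guess_cf_axis(): axis attr > standard_name > name
    if "axis" in attrs:
        return attrs["axis"].upper()
    std = attrs.get("standard_name", "").lower()
    table = {"longitude": "X", "latitude": "Y", "time": "T"}
    if std in table:
        return table[std]
    n = name.lower()
    if "lon" in n:
        return "X"
    if "lat" in n:
        return "Y"
    if "time" in n:
        return "T"
    if any(p in n for p in ("lev", "pressure", "height", "alt")):
        return "Z"
    return None

print(guess_axis("latitude", {}))  # → Y
```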
67
+ def identify_coordinates(da: xr.DataArray) -> Dict[str, str]:
+     """
+     Identify coordinate types in a DataArray.
+ 
+     Args:
+         da: Input DataArray
+ 
+     Returns:
+         Dictionary mapping axis type to coordinate name
+     """
+     coords = {}
+ 
+     for coord_name in da.dims:
+         axis = guess_cf_axis(da, coord_name)
+         if axis:
+             coords[axis] = coord_name
+ 
+     return coords
+ 
+ 
+ def get_crs(da: xr.DataArray) -> Optional[CRS]:
+     """
+     Extract CRS information from a DataArray.
+ 
+     Args:
+         da: Input DataArray
+ 
+     Returns:
+         pyproj CRS object or None
+     """
+     # Check for grid_mapping attribute
+     grid_mapping = da.attrs.get('grid_mapping')
+     if grid_mapping and grid_mapping in da.coords:
+         gm_var = da.coords[grid_mapping]
+ 
+         # Try to construct CRS from grid mapping attributes
+         try:
+             return CRS.from_cf(dict(gm_var.attrs))
+         except Exception:
+             pass
+ 
+     # Check for crs coordinate
+     if 'crs' in da.coords:
+         try:
+             return CRS.from_cf(dict(da.coords['crs'].attrs))
+         except Exception:
+             pass
+ 
+     # Check for spatial_ref coordinate (common in rioxarray)
+     if 'spatial_ref' in da.coords:
+         try:
+             spatial_ref = da.coords['spatial_ref']
+             if hasattr(spatial_ref, 'spatial_ref'):
+                 return CRS.from_wkt(spatial_ref.spatial_ref)
+             elif 'crs_wkt' in spatial_ref.attrs:
+                 return CRS.from_wkt(spatial_ref.attrs['crs_wkt'])
+         except Exception:
+             pass
+ 
+     # Fall back to WGS84 if the X/Y coordinates look geographic
+     coords = identify_coordinates(da)
+     if 'X' in coords and 'Y' in coords:
+         x_coord = da.coords[coords['X']]
+         y_coord = da.coords[coords['Y']]
+ 
+         if (-180 <= x_coord.min() <= x_coord.max() <= 360 and
+                 -90 <= y_coord.min() <= y_coord.max() <= 90):
+             return CRS.from_epsg(4326)  # WGS84
+ 
+     return None
+ 
+ 
+ def is_geographic(da: xr.DataArray) -> bool:
+     """Check if DataArray uses geographic coordinates."""
+     crs = get_crs(da)
+     if crs:
+         return crs.is_geographic
+ 
+     # Fallback: check coordinate ranges
+     coords = identify_coordinates(da)
+     if 'X' in coords and 'Y' in coords:
+         x_coord = da.coords[coords['X']]
+         y_coord = da.coords[coords['Y']]
+ 
+         return (-180 <= x_coord.min() <= x_coord.max() <= 360 and
+                 -90 <= y_coord.min() <= y_coord.max() <= 90)
+ 
+     return False
+ 
+ 
+ def ensure_longitude_range(da: xr.DataArray, range_type: str = '180') -> xr.DataArray:
+     """
+     Ensure longitude coordinates are in the specified range.
+ 
+     Args:
+         da: Input DataArray
+         range_type: '180' for [-180, 180] or '360' for [0, 360]
+ 
+     Returns:
+         DataArray with adjusted longitude coordinates
+     """
+     coords = identify_coordinates(da)
+     if 'X' not in coords:
+         return da
+ 
+     x_coord = coords['X']
+     da_copy = da.copy()
+ 
+     if range_type == '180':
+         # Convert to [-180, 180]
+         da_copy.coords[x_coord] = ((da_copy.coords[x_coord] + 180) % 360) - 180
+     elif range_type == '360':
+         # Convert to [0, 360]
+         da_copy.coords[x_coord] = da_copy.coords[x_coord] % 360
+ 
+     # Re-sort so the longitude dimension stays monotonic after wrapping
+     if x_coord in da_copy.dims:
+         da_copy = da_copy.sortby(x_coord)
+ 
+     return da_copy
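The modular wrap used in `ensure_longitude_range` can be sanity-checked without xarray. A minimal sketch of the same arithmetic in plain Python (`to_180`/`to_360` are illustrative helper names, not part of the module):

```python
def to_180(lon: float) -> float:
    # ((lon + 180) % 360) - 180 maps any angle into the half-open range [-180, 180)
    return ((lon + 180) % 360) - 180

def to_360(lon: float) -> float:
    # lon % 360 maps any angle into [0, 360); Python's % always returns a
    # non-negative result for a positive modulus, so negative inputs wrap too
    return lon % 360

print(to_180(270.0))  # -90.0
print(to_360(-90.0))  # 270.0
```

Note that 180 itself wraps to -180 under `to_180`, which is why the sketch describes the range as half-open; the `sortby` call in the function above keeps the coordinate monotonic after such wrap-arounds.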
+ 
+ 
+ def get_time_bounds(da: xr.DataArray) -> Optional[Tuple[Any, Any]]:
+     """Get time bounds from a DataArray."""
+     coords = identify_coordinates(da)
+     if 'T' not in coords:
+         return None
+ 
+     time_coord = da.coords[coords['T']]
+     return (time_coord.min().values, time_coord.max().values)
+ 
+ 
+ def format_value(value: Any, coord_name: str = '') -> str:
+     """Format a coordinate value for display."""
+     if isinstance(value, np.datetime64):
+         return str(value)[:19]  # Truncate to seconds precision
+     elif isinstance(value, (int, np.integer)):
+         return str(value)
+     elif isinstance(value, (float, np.floating)):
+         if 'time' in coord_name.lower():
+             return f"{value:.1f}"
+         return f"{value:.3f}"
+     else:
+         return str(value)
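The precedence in `guess_cf_axis` (explicit `axis` attribute, then `standard_name`, then name patterns, then `units`) can be illustrated with a stripped-down, xarray-free version. `guess_axis` here is a hypothetical stand-in that works on a name and a plain attribute dict, covering only a subset of the checks above:

```python
def guess_axis(name: str, attrs: dict):
    # 1. Explicit CF 'axis' attribute wins outright
    if 'axis' in attrs:
        return attrs['axis'].upper()
    # 2. CF standard_name
    sn = attrs.get('standard_name', '').lower()
    if sn == 'longitude':
        return 'X'
    if sn == 'latitude':
        return 'Y'
    if sn == 'time':
        return 'T'
    # 3. Coordinate name patterns
    n = name.lower()
    if 'lon' in n:
        return 'X'
    if 'lat' in n:
        return 'Y'
    # 4. Units as a last resort
    units = attrs.get('units', '').lower()
    if units == 'degrees_east':
        return 'X'
    if units == 'degrees_north':
        return 'Y'
    return None

print(guess_axis('xc', {'axis': 'X'}))               # X
print(guess_axis('nav_lat', {}))                     # Y
print(guess_axis('col', {'units': 'degrees_east'}))  # X
```

The ordering matters: a coordinate named `x` with `standard_name: latitude` is classified as Y, because metadata is more trustworthy than naming conventions.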
tests/test_io.py ADDED
@@ -0,0 +1,102 @@
+ """Tests for I/O operations."""
+ 
+ import pytest
+ import numpy as np
+ import xarray as xr
+ import tempfile
+ import os
+ from tensorview.io import open_any, list_variables, get_dataarray, detect_engine
+ 
+ 
+ def create_sample_netcdf():
+     """Create a sample NetCDF file for testing."""
+     # Create sample coordinates
+     lons = np.arange(-180, 180, 2.5)
+     lats = np.arange(-90, 90, 2.5)
+     times = np.arange(0, 10)
+ 
+     # Create sample temperature data
+     temp_data = np.random.randn(len(times), len(lats), len(lons)) + 20
+ 
+     # Create xarray Dataset
+     ds = xr.Dataset({
+         'temperature': (['time', 'lat', 'lon'], temp_data, {
+             'units': 'degrees_C',
+             'long_name': 'Temperature',
+             'standard_name': 'air_temperature'
+         })
+     }, coords={
+         'time': ('time', times, {'units': 'days since 2000-01-01'}),
+         'lat': ('lat', lats, {'units': 'degrees_north', 'long_name': 'Latitude'}),
+         'lon': ('lon', lons, {'units': 'degrees_east', 'long_name': 'Longitude'})
+     })
+ 
+     # Save to a temporary file; close the handle first so the
+     # write to temp_file.name succeeds on all platforms
+     temp_file = tempfile.NamedTemporaryFile(suffix='.nc', delete=False)
+     temp_file.close()
+     ds.to_netcdf(temp_file.name)
+ 
+     return temp_file.name
+ 
+ 
+ def test_detect_engine():
+     """Test engine detection."""
+     assert detect_engine('test.nc') == 'h5netcdf'
+     assert detect_engine('test.grib') == 'cfgrib'
+     assert detect_engine('test.zarr') == 'zarr'
+     assert detect_engine('test.h5') == 'h5netcdf'
+ 
+ 
+ def test_open_netcdf():
+     """Test opening NetCDF files."""
+     nc_file = create_sample_netcdf()
+ 
+     try:
+         # Test opening
+         handle = open_any(nc_file)
+         assert handle is not None
+         assert handle.engine in ['h5netcdf', 'netcdf4']
+ 
+         # Test variable listing
+         variables = list_variables(handle)
+         assert len(variables) == 1
+         assert variables[0].name == 'temperature'
+         assert variables[0].units == 'degrees_C'
+ 
+         # Test getting data array
+         da = get_dataarray(handle, 'temperature')
+         assert da.name == 'temperature'
+         assert len(da.dims) == 3
+         assert 'time' in da.dims
+         assert 'lat' in da.dims
+         assert 'lon' in da.dims
+ 
+         # Clean up
+         handle.close()
+ 
+     finally:
+         os.unlink(nc_file)
+ 
+ 
+ def test_invalid_file():
+     """Test error handling for invalid files."""
+     with pytest.raises(RuntimeError):
+         open_any('nonexistent_file.nc')
+ 
+ 
+ def test_invalid_variable():
+     """Test error handling for invalid variables."""
+     nc_file = create_sample_netcdf()
+ 
+     try:
+         handle = open_any(nc_file)
+ 
+         with pytest.raises(ValueError):
+             get_dataarray(handle, 'nonexistent_variable')
+ 
+         handle.close()
+ 
+     finally:
+         os.unlink(nc_file)
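The extension-to-engine mapping exercised by `test_detect_engine` can be sketched as a plain lookup table. This is an illustrative reimplementation consistent with the test's assertions, not the actual `tensorview.io.detect_engine`; the `.hdf`/`.grb` entries and the default are assumptions:

```python
import os

# Assumed mapping, consistent with the assertions in test_detect_engine
ENGINE_BY_EXT = {
    '.nc': 'h5netcdf',
    '.h5': 'h5netcdf',
    '.hdf': 'h5netcdf',
    '.grib': 'cfgrib',
    '.grb': 'cfgrib',
    '.zarr': 'zarr',
}

def detect_engine_sketch(path: str) -> str:
    # Normalize the extension so 'test.GRIB' resolves the same as 'test.grib'
    ext = os.path.splitext(path)[1].lower()
    return ENGINE_BY_EXT.get(ext, 'h5netcdf')  # assumed default for unknown extensions

print(detect_engine_sketch('test.grib'))  # cfgrib
```

A lookup keyed on `os.path.splitext` also works for Zarr, since a `store.zarr` directory still yields the `.zarr` extension.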
tests/test_plot.py ADDED
@@ -0,0 +1,100 @@
+ """Tests for plotting functionality."""
+ 
+ import pytest
+ import numpy as np
+ import xarray as xr
+ import matplotlib.pyplot as plt
+ from tensorview.plot import plot_1d, plot_2d, setup_matplotlib
+ 
+ 
+ def create_sample_data():
+     """Create sample data for testing."""
+     # 1D data
+     x = np.linspace(0, 10, 100)
+     y = np.sin(x)
+     da_1d = xr.DataArray(y, coords={'x': x}, dims=['x'],
+                          attrs={'units': 'm/s', 'long_name': 'Sine Wave'})
+ 
+     # 2D data
+     lons = np.linspace(-10, 10, 20)
+     lats = np.linspace(-10, 10, 15)
+     lon_grid, lat_grid = np.meshgrid(lons, lats)
+     temp_data = np.sin(lon_grid / 5) * np.cos(lat_grid / 5) + np.random.randn(*lat_grid.shape) * 0.1
+ 
+     da_2d = xr.DataArray(temp_data,
+                          coords={'lat': lats, 'lon': lons},
+                          dims=['lat', 'lon'],
+                          attrs={'units': 'degrees_C', 'long_name': 'Temperature'})
+ 
+     return da_1d, da_2d
+ 
+ 
+ def test_setup_matplotlib():
+     """Test matplotlib setup."""
+     setup_matplotlib()
+     assert plt.get_backend() == 'Agg'
+ 
+ 
+ def test_plot_1d():
+     """Test 1D plotting."""
+     da_1d, _ = create_sample_data()
+ 
+     fig = plot_1d(da_1d)
+     assert fig is not None
+     assert len(fig.axes) == 1
+ 
+     ax = fig.axes[0]
+     assert len(ax.lines) == 1
+     assert ax.get_xlabel() == 'x ()'
+     assert 'Sine Wave' in ax.get_title()
+ 
+     plt.close(fig)
+ 
+ 
+ def test_plot_2d():
+     """Test 2D plotting."""
+     _, da_2d = create_sample_data()
+ 
+     # Test image plot
+     fig = plot_2d(da_2d, kind="image")
+     assert fig is not None
+     assert len(fig.axes) >= 1  # Plot axis + possibly colorbar axis
+ 
+     ax = fig.axes[0]
+     assert ax.get_xlabel() == 'lon ()'
+     assert ax.get_ylabel() == 'lat ()'
+     assert 'Temperature' in ax.get_title()
+ 
+     plt.close(fig)
+ 
+     # Test contour plot
+     fig = plot_2d(da_2d, kind="contour")
+     assert fig is not None
+     plt.close(fig)
+ 
+ 
+ def test_plot_styling():
+     """Test plot styling options."""
+     da_1d, da_2d = create_sample_data()
+ 
+     # Test 1D styling
+     fig = plot_1d(da_1d, color='red', linewidth=2, grid=False)
+     ax = fig.axes[0]
+     assert ax.lines[0].get_color() == 'red'
+     assert ax.lines[0].get_linewidth() == 2
+     plt.close(fig)
+ 
+     # Test 2D styling
+     fig = plot_2d(da_2d, cmap='plasma', vmin=-1, vmax=1)
+     assert fig is not None
+     plt.close(fig)
+ 
+ 
+ def test_auto_dimension_detection():
+     """Test automatic dimension detection."""
+     _, da_2d = create_sample_data()
+ 
+     # Should work without specifying dimensions
+     fig = plot_2d(da_2d)
+     assert fig is not None
+     plt.close(fig)