|
|
"""TensorView - Simple HF-compatible NetCDF viewer.""" |
|
|
|
|
|
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.figure import Figure
|
|
|
|
|
|
|
|
# Select matplotlib's non-interactive Agg backend so figures can be rendered
# to images without a GUI display.
plt.switch_backend('Agg')

# Module-level state written by load_file and read by create_simple_plot:
# the most recently opened xarray Dataset (None until a file is loaded) and
# the names of its data variables.
# NOTE(review): these are plain module globals, so presumably every session
# of the app shares a single loaded dataset — confirm this is intended.
current_dataset = None
current_variables = []
|
|
|
|
|
def load_file(file):
    """Open an uploaded NetCDF/HDF file and summarise it for the UI.

    On success the opened dataset and its data-variable names are stored in
    the module globals ``current_dataset`` and ``current_variables`` for
    ``create_simple_plot`` to use. Any previously opened dataset is closed
    first, and both globals are reset if opening fails, so a failed upload
    never leaves stale state behind.

    Args:
        file: Value of the ``gr.File`` component. This is a path string (or a
            str subclass), an ``os.PathLike``, or a tempfile-like object with
            a ``.name`` attribute. It is ``None`` when nothing was uploaded.

    Returns:
        ``(info_markdown, dropdown_update, status_text)``. The middle item is
        a ``gr.update`` that replaces the variable dropdown's *choices*.
        Returning a bare list would only set the dropdown's value and leave
        it with no selectable options.
    """
    global current_dataset, current_variables

    if file is None:
        return "No file uploaded.", gr.update(choices=[], value=None), ""

    try:
        # Release the file handle held by the previous upload before
        # replacing it; otherwise every upload leaks an open file.
        if current_dataset is not None:
            current_dataset.close()
        current_dataset = None
        current_variables = []

        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name

        current_dataset = xr.open_dataset(path, chunks="auto")
        current_variables = list(current_dataset.data_vars)

        # ``Dataset.sizes`` is the supported name -> length mapping;
        # ``dict(Dataset.dims)`` is deprecated in recent xarray releases.
        lines = [
            "✅ **Dataset loaded successfully!**",
            "",
            f"**File:** {os.path.basename(path)}",
            "",
            f"**Dimensions:** {dict(current_dataset.sizes)}",
            "",
            f"**Variables:** {len(current_variables)}",
            "",
            "### Available Variables:",
            "",
        ]

        # Only the first few variables are listed to keep the panel short;
        # say so explicitly instead of silently truncating.
        max_listed = 10
        for var in current_variables[:max_listed]:
            da = current_dataset[var]
            lines.append(
                f"- **{var}**: {da.shape} [{da.attrs.get('units', 'N/A')}] "
                f"- {da.attrs.get('long_name', var)}"
            )
        hidden = len(current_variables) - max_listed
        if hidden > 0:
            lines.append(f"- … and {hidden} more")

        info = "\n".join(lines) + "\n"
        dropdown = gr.update(
            choices=current_variables,
            value=current_variables[0] if current_variables else None,
        )
        return info, dropdown, "Dataset loaded successfully!"

    except Exception as e:
        # UI boundary: report the problem in the info panel and leave the
        # app in a clean "nothing loaded" state.
        if current_dataset is not None:
            current_dataset.close()
        current_dataset = None
        current_variables = []
        return f"❌ Error: {str(e)}", gr.update(choices=[], value=None), ""
|
|
|
|
|
# Substrings that mark a dimension name as spatial when slicing data down to
# 2-D; dimensions matching none of them are collapsed to their first index.
_SPATIAL_HINTS = ('lat', 'lon', 'x', 'y')
# Exact (case-insensitive) dimension names recognised for map plots.
_LAT_NAMES = ('lat', 'latitude')
_LON_NAMES = ('lon', 'longitude')


def _new_figure(figsize, projection=None):
    """Create a figure with a single axes, without registering it with pyplot.

    ``plt.subplots`` keeps every figure alive in pyplot's global registry
    until ``plt.close`` is called, which leaks memory in a long-running
    server. A bare ``Figure`` is garbage-collected once Gradio has rendered
    it.
    """
    fig = Figure(figsize=figsize)
    ax = fig.add_subplot(111, projection=projection)
    return fig, ax


def _message_figure(message, **text_kwargs):
    """Return a placeholder figure showing *message* centred in the axes."""
    fig, ax = _new_figure((10, 6))
    ax.text(0.5, 0.5, message, ha='center', va='center',
            transform=ax.transAxes, **text_kwargs)
    return fig


def _reduce_to_2d(da):
    """Slice *da* down to at most two dimensions by selecting index 0.

    A dimension whose lower-cased name contains none of ``_SPATIAL_HINTS`` is
    dropped first. If every dimension looks spatial, the leading one is
    dropped instead.
    """
    while da.ndim > 2:
        non_spatial = [
            dim for dim in da.dims
            if not any(hint in str(dim).lower() for hint in _SPATIAL_HINTS)
        ]
        dropped = non_spatial[0] if non_spatial else da.dims[0]
        da = da.isel({dropped: 0})
    return da


def _find_dim(da, names):
    """Return the first dimension of *da* whose lower-cased name is in *names*."""
    return next((dim for dim in da.dims if str(dim).lower() in names), None)


def _robust_limits(values):
    """Return colour limits from the 2nd/98th percentiles of finite *values*.

    Falls back to the full finite range when the percentiles coincide, and to
    ``(0, 1)`` when there are no finite values at all.
    """
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0, 1
    vmin, vmax = np.percentile(finite, [2, 98])
    if vmin == vmax:
        vmin, vmax = finite.min(), finite.max()
    return vmin, vmax


def _plot_2d(da, variable, plot_type, colormap):
    """Draw a 2-D DataArray as a cartopy map (when possible) or a plain image."""
    lat_dim = _find_dim(da, _LAT_NAMES)
    lon_dim = _find_dim(da, _LON_NAMES)

    ccrs = cfeature = None
    if plot_type == "Map" and lat_dim is not None and lon_dim is not None:
        try:
            import cartopy.crs as ccrs
            import cartopy.feature as cfeature
        except ImportError:
            # cartopy is optional; fall back to a plain image plot.
            ccrs = cfeature = None

    if ccrs is not None:
        # pcolormesh(x, y, C) expects C shaped (len(y), len(x)); make the
        # order explicit so (lon, lat)-ordered data does not crash.
        da = da.transpose(lat_dim, lon_dim)

    # Materialise the (possibly dask-backed) data once. ``compute`` returns a
    # new object, so the shared dataset itself stays lazy.
    da = da.compute()
    values = da.values
    vmin, vmax = _robust_limits(values)
    label = f"{variable} ({da.attrs.get('units', '')})"

    if ccrs is not None:
        fig, ax = _new_figure((12, 8), projection=ccrs.PlateCarree())
        im = ax.pcolormesh(da[lon_dim].values, da[lat_dim].values, values,
                           transform=ccrs.PlateCarree(),
                           cmap=colormap, vmin=vmin, vmax=vmax, shading='auto')
        ax.coastlines(resolution='50m')
        ax.gridlines(draw_labels=True, alpha=0.5)
        ax.add_feature(cfeature.BORDERS, linewidth=0.5)
        cbar = fig.colorbar(im, ax=ax, orientation='horizontal',
                            pad=0.05, shrink=0.8)
        cbar.set_label(label)
    else:
        fig, ax = _new_figure((12, 8))
        im = ax.imshow(values, aspect='auto', origin='lower',
                       cmap=colormap, vmin=vmin, vmax=vmax)
        ax.set_xlabel(da.dims[1])
        ax.set_ylabel(da.dims[0])
        fig.colorbar(im, ax=ax, label=label)

    ax.set_title(f"{da.attrs.get('long_name', variable)}")
    return fig


def _plot_1d(da, variable):
    """Draw a 1-D DataArray as a line plot against its dimension coordinate."""
    dim = da.dims[0]
    # ``da[dim]`` falls back to an integer index when the dimension has no
    # coordinate variable, whereas ``da.coords[dim]`` would raise KeyError.
    coord = da[dim]
    da = da.compute()

    fig, ax = _new_figure((12, 8))
    ax.plot(coord.values, da.values)
    ax.set_xlabel(f"{dim} ({coord.attrs.get('units', '')})")
    ax.set_ylabel(f"{variable} ({da.attrs.get('units', '')})")
    ax.set_title(f"{da.attrs.get('long_name', variable)}")
    ax.grid(True, alpha=0.3)
    return fig


def create_simple_plot(variable, plot_type, colormap):
    """Plot one variable of the currently loaded dataset.

    Data with more than two dimensions is sliced to 2-D first (see
    ``_reduce_to_2d``). 2-D data becomes a map when *plot_type* is ``"Map"``,
    the data has lat/lon dimensions and cartopy is installed; otherwise it is
    drawn as an image. 1-D data becomes a line plot.

    Args:
        variable: Name of a data variable in ``current_dataset``.
        plot_type: ``"2D Image"`` or ``"Map"``.
        colormap: Matplotlib colormap name for 2-D plots.

    Returns:
        A matplotlib ``Figure``. On a missing dataset/variable, or on any
        error, this is a figure containing an explanatory message instead of
        an exception.
    """
    if current_dataset is None or not variable:
        return _message_figure('Please load a dataset first', fontsize=14)

    try:
        da = _reduce_to_2d(current_dataset[variable])

        if da.ndim == 2:
            fig = _plot_2d(da, variable, plot_type, colormap)
        elif da.ndim == 1:
            fig = _plot_1d(da, variable)
        else:
            fig, ax = _new_figure((12, 8))
            ax.text(0.5, 0.5, f'Cannot plot {len(da.dims)}D data',
                    ha='center', va='center', transform=ax.transAxes)

        fig.tight_layout()
        return fig

    except Exception as e:
        # UI boundary: surface the error in the plot area rather than crash.
        return _message_figure(f'Error creating plot:\n{str(e)}', color='red')
|
|
|
|
|
|
|
|
# Gradio UI: upload panel and plot settings on the left, dataset summary and
# plot on the right.
with gr.Blocks(title="TensorView - NetCDF Viewer") as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
        <h1>🌍 TensorView</h1>
        <p><strong>Simple NetCDF/HDF viewer for scientific data</strong></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Upload Data")
            file_input = gr.File(
                label="Upload NetCDF file",
                file_types=[".nc", ".netcdf", ".hdf", ".h5"],
            )

            gr.Markdown("### 🎨 Plot Settings")
            # Choices are filled in by load_file after a successful upload.
            variable_dropdown = gr.Dropdown(
                label="Select Variable",
                choices=[],
                interactive=True,
            )

            plot_type_radio = gr.Radio(
                label="Plot Type",
                choices=["2D Image", "Map"],
                value="2D Image",
            )

            colormap_dropdown = gr.Dropdown(
                label="Colormap",
                choices=["viridis", "plasma", "coolwarm", "RdBu_r", "Blues", "Reds"],
                value="viridis",
            )

            plot_button = gr.Button("Create Plot", variant="primary")

        with gr.Column(scale=2):
            file_info = gr.Markdown("Upload a NetCDF file to begin.")
            plot_output = gr.Plot()

    # Hidden sink for load_file's third (status text) return value.
    status_text = gr.Textbox(visible=False)

    file_input.upload(
        fn=load_file,
        inputs=[file_input],
        outputs=[file_info, variable_dropdown, status_text],
    )

    plot_button.click(
        fn=create_simple_plot,
        inputs=[variable_dropdown, plot_type_radio, colormap_dropdown],
        outputs=[plot_output],
    )


if __name__ == "__main__":
    demo.launch()