ncview / app.py
Nipun's picture
Simplify app for HF compatibility - remove complex dependencies
31e3def
"""TensorView - Simple HF-compatible NetCDF viewer."""
import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
# Render with the non-interactive Agg backend (server environment).
plt.switch_backend("Agg")

# Module-level state shared between the Gradio callbacks below.
current_dataset = None  # dataset opened by the last successful upload, if any
current_variables = []  # names of the data variables in ``current_dataset``
def load_file(file):
    """Open an uploaded NetCDF file and summarise its contents.

    On every real upload the module-level ``current_dataset`` and
    ``current_variables`` are replaced. The previously opened dataset is
    closed first, and the state is left empty if the new file cannot be read.

    Args:
        file: Value delivered by the ``gr.File`` component: either a path
            string or an object exposing the path as ``.name``. ``None``
            means nothing was uploaded.

    Returns:
        A 3-tuple ``(summary_markdown, dropdown_update, status_text)``.
        ``dropdown_update`` is a ``gr.update(...)`` payload that sets the
        variable dropdown's *choices*. Returning a bare list would be
        interpreted by Gradio as the dropdown's value instead.
    """
    global current_dataset, current_variables

    max_listed = 10  # keep the summary short for files with many variables

    if file is None:
        return "No file uploaded.", gr.update(choices=[], value=None), ""

    # Forget the previous dataset up front so a failed load cannot leave the
    # plot callback pointing at stale data.
    previous, current_dataset, current_variables = current_dataset, None, []

    dataset = None
    try:
        if previous is not None:
            previous.close()  # release the old file handle

        # Accept both a plain path and a file-like wrapper with ``.name``.
        path = getattr(file, "name", file)

        dataset = xr.open_dataset(path, chunks="auto")
        variables = list(dataset.data_vars)

        # Blank entries separate Markdown paragraphs; without them the
        # header lines are rendered as one run-on paragraph.
        lines = [
            "✅ **Dataset loaded successfully!**",
            "",
            f"**File:** {os.path.basename(path)}",
            "",
            # ``sizes`` is the name -> length mapping; ``Dataset.dims`` is
            # slated to become a plain set of names.
            f"**Dimensions:** {dict(dataset.sizes)}",
            "",
            f"**Variables:** {len(variables)}",
            "",
            "### Available Variables:",
            "",
        ]
        for var in variables[:max_listed]:
            da = dataset[var]
            lines.append(
                f"- **{var}**: {da.shape} "
                f"[{da.attrs.get('units', 'N/A')}] - "
                f"{da.attrs.get('long_name', var)}"
            )
        if len(variables) > max_listed:
            lines.append(f"- ... and {len(variables) - max_listed} more")
        info = "\n".join(lines) + "\n"
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing.
        if dataset is not None:
            dataset.close()
        return f"❌ Error: {str(e)}", gr.update(choices=[], value=None), ""

    current_dataset, current_variables = dataset, variables
    selected = variables[0] if variables else None
    return (
        info,
        gr.update(choices=variables, value=selected),
        "Dataset loaded successfully!",
    )
# Short dimension-name tokens that denote a horizontal grid axis.
_XY_TOKENS = frozenset({"x", "y", "nx", "ny", "xc", "yc"})


def _is_spatial_dim(name):
    """Return True if a dimension name looks like a horizontal axis.

    ``lat``/``lon`` are matched as substrings. ``x``/``y`` style names are
    matched as whole tokens only, so names such as ``day``, ``layer`` or
    ``index`` are not mistaken for spatial axes.
    """
    lowered = str(name).lower()
    if "lat" in lowered or "lon" in lowered:
        return True
    tokens = "".join(ch if ch.isalnum() else " " for ch in lowered).split()
    return any(token in _XY_TOKENS for token in tokens)


def _find_dim(da, names):
    """Return the first dimension whose lower-cased name is in *names*."""
    return next((dim for dim in da.dims if str(dim).lower() in names), None)


def _reduce_to_2d(da):
    """Take index 0 along extra dimensions until at most two remain."""
    while da.ndim > 2:
        # Prefer dropping a non-spatial dimension; if every dimension looks
        # spatial, fall back to the leading one.
        extra = next(
            (dim for dim in da.dims if not _is_spatial_dim(dim)), da.dims[0]
        )
        da = da.isel({extra: 0})
    return da


def _color_limits(values):
    """Return ``(vmin, vmax)`` from the 2nd/98th percentiles of finite data."""
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0, 1
    vmin, vmax = np.percentile(finite, [2, 98])
    if vmin == vmax:
        # Percentiles collapsed; widen to the full data range.
        vmin, vmax = finite.min(), finite.max()
    return vmin, vmax


def _detach(fig):
    """Remove *fig* from pyplot's registry and return it.

    pyplot keeps every figure it creates alive until it is closed, which
    grows without bound in a long-running server. A closed figure can still
    be saved/rendered by whoever holds the reference.
    """
    plt.close(fig)
    return fig


def _message_figure(text, **text_kwargs):
    """Return a figure showing only a centred text message."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.text(0.5, 0.5, text, ha='center', va='center',
            transform=ax.transAxes, **text_kwargs)
    return _detach(fig)


def create_simple_plot(variable, plot_type, colormap):
    """Plot one variable of the currently loaded dataset.

    Data with more than two dimensions is reduced by taking index 0 along
    the extra dimensions. Two-dimensional data is drawn as an image, or as a
    cartopy map when ``plot_type`` is ``"Map"``, latitude/longitude
    dimensions are present and cartopy is importable. One-dimensional data
    is drawn as a line.

    Args:
        variable: Name of the data variable in ``current_dataset``.
        plot_type: ``"Map"`` requests a geographic plot; any other value
            gives a plain image for 2-D data.
        colormap: Matplotlib colormap name used for 2-D plots.

    Returns:
        A matplotlib figure. If no dataset is loaded, or plotting fails, the
        figure contains an explanatory message instead of data.
    """
    if current_dataset is None or not variable:
        return _message_figure('Please load a dataset first', fontsize=14)

    fig = None
    try:
        da = _reduce_to_2d(current_dataset[variable])
        value_label = f"{variable} ({da.attrs.get('units', '')})"
        title = f"{da.attrs.get('long_name', variable)}"

        if da.ndim == 2:
            lat_dim = _find_dim(da, ('lat', 'latitude'))
            lon_dim = _find_dim(da, ('lon', 'longitude'))

            # Resolve cartopy before any figure exists, so the fallback path
            # never leaves an unused figure behind.
            ccrs = cfeature = None
            if plot_type == "Map" and lat_dim is not None and lon_dim is not None:
                try:
                    import cartopy.crs as ccrs
                    import cartopy.feature as cfeature
                except ImportError:
                    ccrs = cfeature = None  # fall back to a plain image
                # pcolormesh expects rows = latitude, columns = longitude.
                da = da.transpose(lat_dim, lon_dim)

            # Read the data once; each ``.values`` access on a lazily loaded
            # array triggers a fresh read.
            values = da.values
            vmin, vmax = _color_limits(values)

            if ccrs is not None:
                fig, ax = plt.subplots(
                    figsize=(12, 8),
                    subplot_kw={'projection': ccrs.PlateCarree()},
                )
                # ``da[dim]`` also works for dimensions that have no
                # coordinate variable (it yields a 0..n-1 index).
                lons = da[lon_dim].values
                lats = da[lat_dim].values
                im = ax.pcolormesh(lons, lats, values,
                                   transform=ccrs.PlateCarree(),
                                   cmap=colormap, vmin=vmin, vmax=vmax,
                                   shading='auto')
                ax.coastlines(resolution='50m')
                ax.gridlines(draw_labels=True, alpha=0.5)
                ax.add_feature(cfeature.BORDERS, linewidth=0.5)
                cbar = fig.colorbar(im, ax=ax, orientation='horizontal',
                                    pad=0.05, shrink=0.8)
                cbar.set_label(value_label)
            else:
                fig, ax = plt.subplots(figsize=(12, 8))
                im = ax.imshow(values, aspect='auto', origin='lower',
                               cmap=colormap, vmin=vmin, vmax=vmax)
                ax.set_xlabel(da.dims[1])
                ax.set_ylabel(da.dims[0])
                fig.colorbar(im, ax=ax, label=value_label)
            ax.set_title(title)

        elif da.ndim == 1:
            fig, ax = plt.subplots(figsize=(12, 8))
            dim = da.dims[0]
            coord = da[dim]  # index fallback when the dim has no coordinate
            ax.plot(coord.values, da.values)
            ax.set_xlabel(f"{dim} ({coord.attrs.get('units', '')})")
            ax.set_ylabel(value_label)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        else:
            fig, ax = plt.subplots(figsize=(12, 8))
            ax.text(0.5, 0.5, f'Cannot plot {len(da.dims)}D data',
                    ha='center', va='center', transform=ax.transAxes)

        # Figure methods rather than pyplot functions: they act on this
        # figure regardless of which figure pyplot considers "current".
        fig.tight_layout()
        return _detach(fig)
    except Exception as e:
        # UI boundary: show the error in the plot area instead of crashing,
        # and release any half-built figure.
        if fig is not None:
            plt.close(fig)
        return _message_figure(f'Error creating plot:\n{str(e)}', color='red')
# Build the Gradio interface: upload and plot controls on the left, dataset
# summary and plot on the right.
with gr.Blocks(title="TensorView - NetCDF Viewer") as demo:
    gr.HTML("""
    <div style="text-align: center; padding: 20px;">
    <h1>🌍 TensorView</h1>
    <p><strong>Simple NetCDF/HDF viewer for scientific data</strong></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📁 Upload Data")
            file_input = gr.File(
                label="Upload NetCDF file",
                file_types=[".nc", ".netcdf", ".hdf", ".h5"],
            )

            gr.Markdown("### 🎨 Plot Settings")
            variable_dropdown = gr.Dropdown(
                label="Select Variable",
                choices=[],
                interactive=True,
            )
            plot_type_radio = gr.Radio(
                label="Plot Type",
                choices=["2D Image", "Map"],
                value="2D Image",
            )
            colormap_dropdown = gr.Dropdown(
                label="Colormap",
                choices=["viridis", "plasma", "coolwarm", "RdBu_r", "Blues", "Reds"],
                value="viridis",
            )
            plot_button = gr.Button("Create Plot", variant="primary")

        with gr.Column(scale=2):
            file_info = gr.Markdown("Upload a NetCDF file to begin.")
            plot_output = gr.Plot()

    # Hidden sink for the status string that load_file returns as its third
    # value; it is never shown to the user.
    load_status = gr.Textbox(visible=False)

    # Event handlers
    file_input.upload(
        fn=load_file,
        inputs=[file_input],
        outputs=[file_info, variable_dropdown, load_status],
    )
    plot_button.click(
        fn=create_simple_plot,
        inputs=[variable_dropdown, plot_type_radio, colormap_dropdown],
        outputs=[plot_output],
    )

if __name__ == "__main__":
    demo.launch()