File size: 8,022 Bytes
31e3def 433dab5 3aaaab0 433dab5 31e3def 433dab5 31e3def 3aaaab0 433dab5 31e3def 3aaaab0 31e3def 433dab5 31e3def 433dab5 3aaaab0 31e3def 3aaaab0 433dab5 31e3def 3aaaab0 433dab5 3aaaab0 31e3def 3aaaab0 31e3def 3aaaab0 31e3def 3aaaab0 31e3def 433dab5 31e3def 433dab5 31e3def 433dab5 3aaaab0 31e3def 3aaaab0 31e3def 3aaaab0 31e3def 433dab5 31e3def 433dab5 31e3def 433dab5 31e3def 3aaaab0 433dab5 31e3def 433dab5 31e3def 3aaaab0 433dab5 31e3def 433dab5 3aaaab0 31e3def 433dab5 31e3def 3aaaab0 31e3def 3aaaab0 433dab5 3aaaab0 31e3def 3aaaab0 31e3def 3aaaab0 31e3def 433dab5 31e3def 3aaaab0 31e3def 3aaaab0 31e3def 3aaaab0 31e3def |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
"""TensorView - Simple HF-compatible NetCDF viewer."""
import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
# Render with the non-interactive Agg backend: the app runs in a server
# environment without a display.
plt.switch_backend("Agg")

# State shared by the Gradio callbacks: load_file() assigns both names,
# create_simple_plot() reads current_dataset.
current_dataset, current_variables = None, []
def load_file(file):
    """Open an uploaded NetCDF/HDF file, store it globally and summarise it.

    The previously loaded dataset (if any) is closed and replaced: on success
    ``current_dataset`` / ``current_variables`` hold the new file, on failure
    they are reset to ``None`` / ``[]``.

    Args:
        file: The ``gr.File`` upload value, or None. Depending on the Gradio
            version this is either the path string itself or a tempfile-like
            object whose ``.name`` is the path; both are accepted.

    Returns:
        ``(summary_markdown, dropdown_update, status_text)`` - one value per
        output wired to the upload event. ``dropdown_update`` replaces the
        variable dropdown's *choices* (a bare list would only set its value).
    """
    global current_dataset, current_variables
    cleared_dropdown = gr.update(choices=[], value=None)

    if file is None:
        return "No file uploaded.", cleared_dropdown, ""

    path = getattr(file, "name", file)

    try:
        # Detach and close the old dataset first: repeated uploads must not
        # leak open file handles, and a failed open must not leave stale state.
        previous, current_dataset, current_variables = current_dataset, None, []
        if previous is not None:
            previous.close()

        current_dataset = xr.open_dataset(path, chunks="auto")
        current_variables = list(current_dataset.data_vars)

        lines = [
            "✅ **Dataset loaded successfully!**",
            "",
            f"**File:** {os.path.basename(path)}",
            "",
            # ``sizes`` is the dimension name -> length mapping; ``dict(ds.dims)``
            # is deprecated in recent xarray releases.
            f"**Dimensions:** {dict(current_dataset.sizes)}",
            "",
            f"**Variables:** {len(current_variables)}",
            "",
            "### Available Variables:",
        ]
        shown = current_variables[:10]  # keep the summary short
        for var in shown:
            da = current_dataset[var]
            lines.append(
                f"- **{var}**: {da.shape} [{da.attrs.get('units', 'N/A')}]"
                f" - {da.attrs.get('long_name', var)}"
            )
        if len(current_variables) > len(shown):
            lines.append(f"- ... and {len(current_variables) - len(shown)} more")
        info = "\n".join(lines)

        # Populate the dropdown's choices and preselect the first variable.
        dropdown = gr.update(
            choices=current_variables,
            value=current_variables[0] if current_variables else None,
        )
        return info, dropdown, "Dataset loaded successfully!"
    except Exception as e:  # UI boundary: report the failure in the summary panel
        return f"❌ Error: {e}", cleared_dropdown, ""
# Dimension names (compared case-insensitively) that mark a latitude/longitude
# pair eligible for the "Map" plot type.
_LAT_NAMES = ("lat", "latitude")
_LON_NAMES = ("lon", "longitude")


def _message_figure(message, figsize=(10, 6), **text_kwargs):
    """Return a figure showing *message* centred on an otherwise empty axes."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.text(0.5, 0.5, message, ha="center", va="center",
            transform=ax.transAxes, **text_kwargs)
    return fig


def _is_spatial_dim(da, dim):
    """Guess whether *dim* is a horizontal (x/y or lat/lon) dimension of *da*.

    A CF-style ``axis`` attribute on the dimension's coordinate is trusted
    first ("X"/"Y" spatial, "Z"/"T" not); otherwise the lower-cased name is
    checked for the substrings lat/lon/x/y.
    """
    if dim in da.coords:
        axis = str(da.coords[dim].attrs.get("axis", "")).upper()
        if axis in ("X", "Y"):
            return True
        if axis in ("Z", "T"):
            return False
    name = str(dim).lower()
    return any(hint in name for hint in ("lat", "lon", "x", "y"))


def _reduce_to_2d(da):
    """Select index 0 along extra dimensions until *da* has at most two.

    Each round slices the first non-spatial dimension; if every remaining
    dimension looks spatial, the first dimension is sliced instead.
    """
    while da.ndim > 2:
        extra = next((d for d in da.dims if not _is_spatial_dim(da, d)), da.dims[0])
        da = da.isel({extra: 0})
    return da


def _color_limits(values):
    """Return colour limits from the 2nd/98th percentiles of the finite values.

    Falls back to the finite min/max when those percentiles coincide, and to
    (0, 1) when there are no finite values at all.
    """
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0, 1
    vmin, vmax = np.percentile(finite, [2, 98])
    if vmin == vmax:
        vmin, vmax = finite.min(), finite.max()
    return vmin, vmax


def _find_dim(da, names):
    """Return the first dimension of *da* whose lower-cased name is in *names*, else None."""
    return next((d for d in da.dims if str(d).lower() in names), None)


def _draw_map(fig, da, values, colormap, vmin, vmax, label):
    """Draw 2-D *da* on a PlateCarree map axes added to *fig*.

    Returns the new axes, or None when *da* lacks lat/lon dimensions or
    cartopy is not installed, so the caller can fall back to a plain image.
    """
    lat_dim = _find_dim(da, _LAT_NAMES)
    lon_dim = _find_dim(da, _LON_NAMES)
    if lat_dim is None or lon_dim is None:
        return None
    try:
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
    except ImportError:
        return None

    ax = fig.add_subplot(projection=ccrs.PlateCarree())
    # pcolormesh expects C shaped (len(lats), len(lons)); data stored as
    # (lon, lat) has to be transposed to match.
    grid = values if da.dims[0] == lat_dim else values.T
    # da[dim] also works for dimensions without a coordinate variable.
    im = ax.pcolormesh(da[lon_dim].values, da[lat_dim].values, grid,
                       transform=ccrs.PlateCarree(), cmap=colormap,
                       vmin=vmin, vmax=vmax, shading="auto")
    ax.coastlines(resolution="50m")
    ax.gridlines(draw_labels=True, alpha=0.5)
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    cbar = fig.colorbar(im, ax=ax, orientation="horizontal", pad=0.05, shrink=0.8)
    cbar.set_label(label)
    return ax


def _draw_2d(fig, da, variable, plot_type, colormap):
    """Draw 2-D *da* as a map (if requested and possible) or an index-space image."""
    # Read the data once: for lazily (e.g. dask) backed arrays every .values
    # access would recompute it.
    values = da.values
    vmin, vmax = _color_limits(values)
    label = f"{variable} ({da.attrs.get('units', '')})"

    ax = _draw_map(fig, da, values, colormap, vmin, vmax, label) if plot_type == "Map" else None
    if ax is None:
        # Array rows (dims[0]) run up the y axis, columns (dims[1]) along x.
        ax = fig.add_subplot()
        im = ax.imshow(values, aspect="auto", origin="lower",
                       cmap=colormap, vmin=vmin, vmax=vmax)
        ax.set_xlabel(da.dims[1])
        ax.set_ylabel(da.dims[0])
        fig.colorbar(im, ax=ax, label=label)
    ax.set_title(f"{da.attrs.get('long_name', variable)}")


def _draw_1d(fig, da, variable):
    """Line-plot 1-D *da* against the values of its dimension."""
    dim = da.dims[0]
    # Unlike da.coords[dim], da[dim] also works when the dimension has no
    # coordinate variable (xarray then supplies a 0..n-1 index).
    coord = da[dim]
    ax = fig.add_subplot()
    ax.plot(coord.values, da.values)
    ax.set_xlabel(f"{dim} ({coord.attrs.get('units', '')})")
    ax.set_ylabel(f"{variable} ({da.attrs.get('units', '')})")
    ax.set_title(f"{da.attrs.get('long_name', variable)}")
    ax.grid(True, alpha=0.3)


def create_simple_plot(variable, plot_type, colormap):
    """Plot one variable of the loaded dataset and return the matplotlib figure.

    Variables with more than two dimensions are reduced to 2-D by taking
    index 0 along extra dimensions (non-spatial ones first). 2-D data is drawn
    as an image, or as a cartopy map when *plot_type* is "Map", the data has
    lat/lon dimensions and cartopy is importable; 1-D data is a line plot.

    Args:
        variable: Name of a data variable in ``current_dataset`` (may be empty).
        plot_type: "2D Image" or "Map".
        colormap: Matplotlib colormap name used for 2-D plots.

    Returns:
        A matplotlib Figure. Problems (no dataset, no variable selected,
        plotting errors) are drawn as text in the figure rather than raised.
    """
    if current_dataset is None:
        fig = _message_figure("Please load a dataset first", fontsize=14)
    elif not variable:
        fig = _message_figure("Please select a variable", fontsize=14)
    else:
        fig = plt.figure(figsize=(12, 8))
        try:
            da = _reduce_to_2d(current_dataset[variable])
            if da.ndim == 2:
                _draw_2d(fig, da, variable, plot_type, colormap)
            elif da.ndim == 1:
                _draw_1d(fig, da, variable)
            else:
                ax = fig.add_subplot()
                ax.text(0.5, 0.5, f"Cannot plot {da.ndim}D data",
                        ha="center", va="center", transform=ax.transAxes)
            fig.tight_layout()
        except Exception as e:  # UI boundary: show the error instead of raising
            plt.close(fig)
            fig = _message_figure(f"Error creating plot:\n{e}", color="red")
    # Unregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself remains usable for rendering/saving.
    plt.close(fig)
    return fig
# Option lists for the UI controls.
_ACCEPTED_FILE_TYPES = [".nc", ".netcdf", ".hdf", ".h5"]
_PLOT_TYPES = ["2D Image", "Map"]
_COLORMAPS = ["viridis", "plasma", "coolwarm", "RdBu_r", "Blues", "Reds"]

_HEADER_HTML = """
<div style="text-align: center; padding: 20px;">
<h1>π TensorView</h1>
<p><strong>Simple NetCDF/HDF viewer for scientific data</strong></p>
</div>
"""

# Assemble the app: upload and plot controls on the left, the dataset
# summary and rendered figure on the right.
with gr.Blocks(title="TensorView - NetCDF Viewer") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### π Upload Data")
            file_input = gr.File(label="Upload NetCDF file", file_types=_ACCEPTED_FILE_TYPES)

            gr.Markdown("### π¨ Plot Settings")
            variable_dropdown = gr.Dropdown(label="Select Variable", choices=[], interactive=True)
            plot_type_radio = gr.Radio(label="Plot Type", choices=_PLOT_TYPES, value=_PLOT_TYPES[0])
            colormap_dropdown = gr.Dropdown(label="Colormap", choices=_COLORMAPS, value=_COLORMAPS[0])
            plot_button = gr.Button("Create Plot", variant="primary")

        with gr.Column(scale=2):
            file_info = gr.Markdown("Upload a NetCDF file to begin.")
            plot_output = gr.Plot()

    # Invisible sink for load_file's third (status) output.
    status_text = gr.Textbox(visible=False)

    # An upload loads the file and refreshes the summary and variable list.
    file_input.upload(
        fn=load_file,
        inputs=[file_input],
        outputs=[file_info, variable_dropdown, status_text],
    )
    # The button renders the selected variable with the chosen settings.
    plot_button.click(
        fn=create_simple_plot,
        inputs=[variable_dropdown, plot_type_radio, colormap_dropdown],
        outputs=[plot_output],
    )

if __name__ == "__main__":
    demo.launch()