"""TensorView - Simple HF-compatible NetCDF viewer.""" import os import gradio as gr import matplotlib.pyplot as plt import numpy as np import xarray as xr # Set matplotlib backend for server environment plt.switch_backend('Agg') # Global state current_dataset = None current_variables = [] def load_file(file): """Load NetCDF file and return info.""" global current_dataset, current_variables if file is None: return "No file uploaded.", [], "" try: # Load dataset current_dataset = xr.open_dataset(file.name, chunks="auto") current_variables = list(current_dataset.data_vars.keys()) # Create summary info = f"""✅ **Dataset loaded successfully!** **File:** {os.path.basename(file.name)} **Dimensions:** {dict(current_dataset.dims)} **Variables:** {len(current_variables)} ### Available Variables: """ for var in current_variables[:10]: # Show first 10 da = current_dataset[var] info += f"- **{var}**: {da.shape} [{da.attrs.get('units', 'N/A')}] - {da.attrs.get('long_name', var)}\n" return info, current_variables, "Dataset loaded successfully!" except Exception as e: return f"❌ Error: {str(e)}", [], "" def create_simple_plot(variable, plot_type, colormap): """Create a simple plot.""" global current_dataset if current_dataset is None or not variable: fig, ax = plt.subplots(figsize=(10, 6)) ax.text(0.5, 0.5, 'Please load a dataset first', ha='center', va='center', transform=ax.transAxes, fontsize=14) return fig try: da = current_dataset[variable] # Reduce dimensions to 2D by taking first slice of extra dimensions while len(da.dims) > 2: # Find first non-spatial dimension and take slice 0 for dim in da.dims: if not any(spatial in dim.lower() for spatial in ['lat', 'lon', 'x', 'y']): da = da.isel({dim: 0}) break else: # If no non-spatial dims, just take first dimension da = da.isel({da.dims[0]: 0}) # Create plot fig, ax = plt.subplots(figsize=(12, 8)) if len(da.dims) == 2: # Calculate good color limits using percentiles values = da.values finite_values = values[np.isfinite(values)] if len(finite_values) > 0: vmin = np.percentile(finite_values, 2) vmax = np.percentile(finite_values, 98) if vmin == vmax: vmin, vmax = finite_values.min(), finite_values.max() else: vmin, vmax = 0, 1 # Create plot based on type if plot_type == "Map" and any(coord in da.dims for coord in ['lat', 'latitude']) and any(coord in da.dims for coord in ['lon', 'longitude']): # Try to create a geographic plot try: import cartopy.crs as ccrs import cartopy.feature as cfeature fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={'projection': ccrs.PlateCarree()}) # Get lat/lon coordinates lat_dim = next(dim for dim in da.dims if 'lat' in dim.lower()) lon_dim = next(dim for dim in da.dims if 'lon' in dim.lower()) lons = da.coords[lon_dim].values lats = da.coords[lat_dim].values # Create the plot im = ax.pcolormesh(lons, lats, da.values, transform=ccrs.PlateCarree(), cmap=colormap, vmin=vmin, vmax=vmax, shading='auto') # Add map features ax.coastlines(resolution='50m') ax.gridlines(draw_labels=True, alpha=0.5) ax.add_feature(cfeature.BORDERS, linewidth=0.5) # Add colorbar cbar = plt.colorbar(im, ax=ax, orientation='horizontal', pad=0.05, shrink=0.8) cbar.set_label(f"{variable} ({da.attrs.get('units', '')})") except ImportError: # Fallback to regular image plot if cartopy not available im = ax.imshow(da.values, aspect='auto', origin='lower', cmap=colormap, vmin=vmin, vmax=vmax) ax.set_xlabel(da.dims[1]) ax.set_ylabel(da.dims[0]) plt.colorbar(im, ax=ax, label=f"{variable} ({da.attrs.get('units', '')})") else: # Regular 2D image plot im = ax.imshow(da.values, aspect='auto', origin='lower', cmap=colormap, vmin=vmin, vmax=vmax) ax.set_xlabel(da.dims[1]) ax.set_ylabel(da.dims[0]) plt.colorbar(im, ax=ax, label=f"{variable} ({da.attrs.get('units', '')})") ax.set_title(f"{da.attrs.get('long_name', variable)}") elif len(da.dims) == 1: # 1D line plot ax.plot(da.coords[da.dims[0]], da.values) ax.set_xlabel(f"{da.dims[0]} ({da.coords[da.dims[0]].attrs.get('units', '')})") ax.set_ylabel(f"{variable} ({da.attrs.get('units', '')})") ax.set_title(f"{da.attrs.get('long_name', variable)}") ax.grid(True, alpha=0.3) else: ax.text(0.5, 0.5, f'Cannot plot {len(da.dims)}D data', ha='center', va='center', transform=ax.transAxes) plt.tight_layout() return fig except Exception as e: fig, ax = plt.subplots(figsize=(10, 6)) ax.text(0.5, 0.5, f'Error creating plot:\n{str(e)}', ha='center', va='center', transform=ax.transAxes, color='red') return fig # Create Gradio interface with gr.Blocks(title="TensorView - NetCDF Viewer") as demo: gr.HTML("""

🌍 TensorView

Simple NetCDF/HDF viewer for scientific data

""") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📁 Upload Data") file_input = gr.File( label="Upload NetCDF file", file_types=[".nc", ".netcdf", ".hdf", ".h5"] ) gr.Markdown("### 🎨 Plot Settings") variable_dropdown = gr.Dropdown( label="Select Variable", choices=[], interactive=True ) plot_type_radio = gr.Radio( label="Plot Type", choices=["2D Image", "Map"], value="2D Image" ) colormap_dropdown = gr.Dropdown( label="Colormap", choices=["viridis", "plasma", "coolwarm", "RdBu_r", "Blues", "Reds"], value="viridis" ) plot_button = gr.Button("Create Plot", variant="primary") with gr.Column(scale=2): file_info = gr.Markdown("Upload a NetCDF file to begin.") plot_output = gr.Plot() # Event handlers file_input.upload( fn=load_file, inputs=[file_input], outputs=[file_info, variable_dropdown, gr.Textbox(visible=False)] ) plot_button.click( fn=create_simple_plot, inputs=[variable_dropdown, plot_type_radio, colormap_dropdown], outputs=[plot_output] ) if __name__ == "__main__": demo.launch()