""" SimEIT Dataset Visualizer A Gradio-based application for visualizing EIT (Electrical Impedance Tomography) datasets from Hugging Face Hub with interactive plots and configurations. Author: Ayman A. Ameen """ import random import numpy as np import h5py import gradio as gr import plotly.graph_objects as go from huggingface_hub import HfFileSystem # ============================================================================ # CONFIGURATION # ============================================================================ DATASET_CONFIG = { 'hf_dataset': 'AymanAmeen/SimEIT-dataset', 'hf_split': 'train', 'hf_subset': 'FourObjects', # Options: 'FourObjects' or 'CirclesOnly' } AVAILABLE_SUBSETS = ['FourObjects', 'CirclesOnly'] AVAILABLE_RESOLUTIONS = ['256', '128_log', '64_log', '32_log'] AVAILABLE_COLORMAPS = [ 'Jet', 'Viridis', 'Plasma', 'Inferno', 'Magma', 'Cividis', 'Hot', 'Cool', 'RdBu', 'RdYlBu', 'Spectral', 'Turbo', 'Blues', 'Greens', 'Reds', 'YlOrRd', 'Portland', 'Picnic' ] # ============================================================================ # DATA LOADER # ============================================================================ class HFDatasetLoader: """ Loads samples from Hugging Face dataset HDF5 file via streaming. Features: - Streams data directly from Hugging Face Hub without downloading - Implements LRU cache for frequently accessed samples - Supports lazy loading of specific resolutions """ def __init__(self, dataset_name, split="train", subset="FourObjects", cache_size=50): """ Initialize the dataset loader. 
Args: dataset_name: Name of the HuggingFace dataset split: Dataset split (default: "train") subset: Dataset subset (e.g., "FourObjects", "CirclesOnly") cache_size: Number of samples to cache (default: 50) """ self.dataset_name = dataset_name self.split = split self.subset = subset self.cache_size = cache_size self._cache = {} self._cache_order = [] print(f"Connecting to dataset {dataset_name} (subset: {subset}) via streaming...") # Initialize HuggingFace filesystem for streaming self.fs = HfFileSystem() self.h5_path = f"datasets/{dataset_name}/{subset}/dataset.h5" # Open HDF5 file in streaming mode and keep it open self._file_handle = self.fs.open(self.h5_path, 'rb') self.h5file = h5py.File(self._file_handle, 'r') # Get dataset size self.num_samples = self.h5file['image']['256'].shape[2] print(f"✓ Dataset connected successfully!") print(f" Total samples: {self.num_samples:,}") print(f" Cache enabled: storing last {cache_size} samples") def __del__(self): """Clean up file handles on object destruction.""" if hasattr(self, 'h5file'): self.h5file.close() if hasattr(self, '_file_handle'): self._file_handle.close() def get_sample(self, index, image_resolution=None): """ Get a specific sample by index from the HDF5 file. 
Args: index: Sample index to load (0 to num_samples-1) image_resolution: Specific resolution to load (e.g., '256', '128_log') If None, loads all resolutions (slower) Returns: dict: Sample data containing voltage and image data Raises: ValueError: If index is out of range """ # Check if index is out of range and clamp to last sample if index < 0 or index >= self.num_samples: print(f"⚠ Index {index} out of range [0, {self.num_samples}), using last sample {self.num_samples - 1}") index = self.num_samples - 1 # Create cache key based on index and resolution cache_key = (index, image_resolution) # Check if already in cache if cache_key in self._cache: print(f"✓ Cache hit for sample {index}, resolution {image_resolution}") return self._cache[cache_key] print(f"Loading sample {index}, resolution {image_resolution}...") sample = {} # Load voltage data (stored as [256, num_samples]) sample['volt_16'] = self.h5file['volt']['16'][:, index] # Lazy load: only load the requested image resolution if image_resolution: sample[f'image_{image_resolution}'] = self.h5file['image'][image_resolution][:, :, index] else: # Load all resolutions (backward compatibility) for res in AVAILABLE_RESOLUTIONS: sample[f'image_{res}'] = self.h5file['image'][res][:, :, index] # Add to cache self._add_to_cache(cache_key, sample) return sample def _add_to_cache(self, key, value): """ Add item to cache with LRU (Least Recently Used) eviction. 
Args: key: Cache key (tuple of index and resolution) value: Sample data to cache """ if key in self._cache: # Move to end (most recent) self._cache_order.remove(key) self._cache_order.append(key) else: # Add new item if len(self._cache) >= self.cache_size: # Evict oldest item oldest_key = self._cache_order.pop(0) del self._cache[oldest_key] self._cache[key] = value self._cache_order.append(key) # ============================================================================ # VISUALIZATION FUNCTIONS # ============================================================================ def create_heatmap_plot(key, index=0, colorscale='Jet'): """ Create a Plotly heatmap from dataset image. Args: key: Image resolution key (e.g., '256', '128_log') index: Sample index colorscale: Plotly colorscale name Returns: plotly.graph_objects.Figure: Heatmap figure """ global _hf_loader try: # Lazy load: only fetch the specific resolution needed sample = _hf_loader.get_sample(index, image_resolution=key) img = sample.get(f'image_{key}') if img is None: print(f"✗ Missing image_{key} in sample {index}") return go.Figure() # Convert to numpy array img = np.array(img) # Handle log-scaled images (negative values) if len(img.shape) == 2 and np.min(img) < 0: img = np.exp(img) # Convert from log back to linear # If RGB image, convert to grayscale for heatmap if len(img.shape) == 3 and img.shape[-1] == 3: img = np.mean(img, axis=2) # Normalize image values using mean and std for this sample img_mean = np.mean(img) img_std = np.std(img) if img_std > 0: # Avoid division by zero img_normalized = (img - img_mean) / img_std else: img_normalized = img - img_mean # Create heatmap fig = go.Figure(data=go.Heatmap( z=img_normalized, colorscale=colorscale, showscale=True, colorbar=dict(title="Normalized Conductivity") )) fig.update_layout( title=dict(text=f"{key} Image (Normalized) - Sample {index}", x=0.5, xanchor='center'), width=450, height=450, xaxis=dict(showticklabels=False, showgrid=False), 
yaxis=dict(showticklabels=False, showgrid=False, scaleanchor="x", scaleratio=1), margin=dict(l=20, r=20, t=50, b=20), autosize=False ) return fig except Exception as e: print(f"✗ Error creating heatmap for image_{key}: {e}") import traceback traceback.print_exc() return go.Figure() def draw_voltage_plot(index=0): """ Draw voltage plot from dataset. Args: index: Sample index Returns: plotly.graph_objects.Figure: Voltage plot figure """ global _hf_loader try: # Load only voltage data (no images needed) sample = _hf_loader.get_sample(index, image_resolution=None) volt_data = sample.get('volt_16') if volt_data is None: print(f"✗ Missing voltage data in sample {index}") return go.Figure() volt_data = np.array(volt_data, dtype=np.float64) if len(volt_data.shape) > 1: volt_data = volt_data.flatten() # Normalize voltage values using mean and std for this sample volt_mean = np.mean(volt_data) volt_std = np.std(volt_data) if volt_std > 0: # Avoid division by zero volt_normalized = (volt_data - volt_mean) / volt_std else: volt_normalized = volt_data - volt_mean electrodes = np.arange(1, len(volt_normalized) + 1) # Create line plot fig = go.Figure() fig.add_trace(go.Scatter( x=electrodes, y=volt_normalized, mode='lines+markers', marker=dict(size=6, color='royalblue'), line=dict(width=2, color='royalblue') )) fig.update_layout( title=dict(text=f"Voltage Measurement (Normalized) - Sample {index}", x=0.5, xanchor='center'), xaxis_title="Electrode Number (n)", yaxis_title="Normalized Voltage (a.u.)", template="plotly_white", showlegend=False, width=450, height=450, margin=dict(l=60, r=20, t=50, b=50), autosize=False ) fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray') fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray') return fig except Exception as e: print(f"✗ Error plotting voltage: {e}") return go.Figure() # ============================================================================ # GRADIO UI HELPER FUNCTIONS # 
# ============================================================================


def get_dataset_info():
    """Get current dataset information string."""
    return (
        f"HuggingFace: {DATASET_CONFIG['hf_dataset']} "
        f"(subset: {_hf_loader.subset}, split: {DATASET_CONFIG['hf_split']})"
    )


def get_max_index():
    """Get maximum valid index in current dataset."""
    return _hf_loader.num_samples - 1


def generate_random_index(state):
    """
    Generate a random valid index and update state.

    Args:
        state: Current state list of indices

    Returns:
        tuple: (random_index, updated_state)
    """
    num = random.randint(0, get_max_index())
    return num, state + [num]


def select_index(n, state):
    """
    Select a specific index with validation.

    Args:
        n: Index to select
        state: Current state list of indices

    Returns:
        tuple: (validated_index, updated_state). When ``n`` is empty the
        first element is "" and when ``n`` is out of range it is an error
        message string; in both cases the state is returned unchanged.
    """
    if n is None or n == "":
        return "", state

    max_idx = get_max_index()
    if not (0 <= n <= max_idx):
        return f"Number must be between 0 and {max_idx}.", state

    chosen = int(n)
    return chosen, state + [chosen]


def show_images(state, image_res, colorscale):
    """
    Display images for the last selected index in state.

    Args:
        state: State list containing selected indices
        image_res: Image resolution to display
        colorscale: Colorscale for heatmap

    Returns:
        tuple: (image_plot, voltage_plot, status_message)
    """
    if not state:
        return go.Figure(), go.Figure(), "No index selected"

    last_index = state[-1]
    return (
        create_heatmap_plot(image_res, last_index, colorscale),
        draw_voltage_plot(last_index),
        f"✓ Loaded sample {last_index} with {image_res} resolution and colormap: {colorscale}"
    )


def generate_random_and_show(state, image_res, colorscale):
    """Generate random index and show corresponding images."""
    num, new_list = generate_random_index(state)
    return (num, new_list) + show_images(new_list, image_res, colorscale)


def select_n_and_show(n, state, image_res, colorscale):
    """
    Select specific index and show corresponding images.

    Returns:
        tuple: (updated_state, image_plot, voltage_plot, status_message).
        If the requested index is rejected, the plots still show the last
        valid selection and the status carries the validation error.
    """
    result, new_list = select_index(n, state)
    image_fig, volt_fig, status = show_images(new_list, image_res, colorscale)

    # select_index reports a rejected value as a non-empty message string;
    # surface it instead of a misleading "Loaded sample" status.
    if isinstance(result, str) and result:
        status = f"✗ {result}"

    return new_list, image_fig, volt_fig, status


def _format_header_markdown():
    """Build the page header markdown for the currently loaded dataset."""
    return (
        "# SimEIT: Dataset Visualizer\n"
        f"**Dataset:** `{get_dataset_info()}` | **Total Samples:** {get_max_index() + 1:,}"
    )


def reload_dataset(subset, state, image_res, colorscale):
    """
    Reload the dataset with a new subset and display a sample.

    Args:
        subset: New subset to load
        state: Current state list
        image_res: Image resolution
        colorscale: Colorscale for heatmap

    Returns:
        tuple: (header markdown, index input, state, image plot,
        voltage plot, status message). On failure the previously loaded
        dataset stays active and the state is returned unchanged.
    """
    global _hf_loader

    try:
        # Build the replacement first: if this raises, the current loader
        # (and the global name) stay intact for every other callback.
        new_loader = HFDatasetLoader(
            DATASET_CONFIG['hf_dataset'],
            DATASET_CONFIG['hf_split'],
            subset,
        )
        # Rebinding drops the reference to the old loader; its __del__
        # closes the old file handles once nothing else refers to it.
        _hf_loader = new_loader

        max_idx = get_max_index()
        info_md = _format_header_markdown()

        # Keep the last selected sample when it exists in the new subset,
        # otherwise pick a random one.
        if state and state[-1] <= max_idx:
            sample_index = state[-1]
        else:
            sample_index = random.randint(0, max_idx)

        # Update state with the new sample
        new_state = [sample_index]

        # Generate plots for the sample
        image_plot = create_heatmap_plot(image_res, sample_index, colorscale)
        volt_plot = draw_voltage_plot(sample_index)

        status_msg = f"✓ Loaded subset: {subset} ({max_idx + 1:,} samples) - Displaying sample {sample_index}"

        return (
            info_md,
            gr.Number(label=f"Enter an integer (0–{max_idx})", precision=0, value=sample_index),
            new_state,
            image_plot,
            volt_plot,
            status_msg
        )

    except Exception as e:
        # UI boundary: report the failure in the status box rather than crash.
        return (
            gr.Markdown(),
            gr.Number(),
            state,
            go.Figure(),
            go.Figure(),
            f"✗ Error loading subset {subset}: {str(e)}"
        )


# ============================================================================
# MAIN APPLICATION
# ============================================================================

# Global dataset loader instance
_hf_loader = None


def create_gradio_interface():
    """
    Create and configure the Gradio interface.

    Also initializes the module-level dataset loader with the default subset
    from DATASET_CONFIG.

    Returns:
        gr.Blocks: Configured Gradio application
    """
    global _hf_loader

    # Initialize configuration
    dataset_name = DATASET_CONFIG['hf_dataset']
    split = DATASET_CONFIG['hf_split']
    default_subset = DATASET_CONFIG['hf_subset']

    # Initialize dataset loader with default subset
    _hf_loader = HFDatasetLoader(dataset_name, split, default_subset)

    with gr.Blocks(title="SimEIT Dataset Visualizer") as demo:
        # Header
        dataset_info_display = gr.Markdown(_format_header_markdown())

        # Controls section
        gr.Markdown("### Choose dataset subset, sample index, image resolution, and colormap")

        with gr.Row():
            with gr.Column():
                subset_selector = gr.Dropdown(
                    choices=AVAILABLE_SUBSETS,
                    value=default_subset,
                    label="Select Dataset Subset"
                )
                user_input = gr.Number(
                    label=f"Enter an integer (0–{get_max_index()})",
                    precision=0
                )
                btn_select_n = gr.Button("Confirm Number")
                btn_random = gr.Button("Generate Random Number")

            with gr.Column():
                image_selector = gr.Dropdown(
                    choices=AVAILABLE_RESOLUTIONS,
                    value='256',
                    label="Select Image Resolution"
                )
                colormap_dropdown = gr.Dropdown(
                    choices=AVAILABLE_COLORMAPS,
                    value='Jet',
                    label="Select Colormap"
                )

        # State for tracking indices
        indices_list = gr.State(value=[])

        # Visualization plots
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                image_plot = gr.Plot(label="Image Heatmap")
            with gr.Column(scale=2):
                volt_plot = gr.Plot(label="Voltage Plot")

        # Status output
        status_output = gr.Textbox(label="Status", interactive=False)

        # Shared input/output groups for the event handlers below
        view_inputs = [indices_list, image_selector, colormap_dropdown]
        plot_outputs = [image_plot, volt_plot, status_output]

        # Event handlers
        subset_selector.change(
            fn=reload_dataset,
            inputs=[subset_selector] + view_inputs,
            outputs=[dataset_info_display, user_input, indices_list] + plot_outputs
        )

        btn_random.click(
            fn=generate_random_and_show,
            inputs=view_inputs,
            outputs=[user_input, indices_list] + plot_outputs
        )

        # The button and the Enter key both confirm the typed number
        for confirm_event in (btn_select_n.click, user_input.submit):
            confirm_event(
                fn=select_n_and_show,
                inputs=[user_input] + view_inputs,
                outputs=[indices_list] + plot_outputs
            )

        # Changing resolution or colormap redraws the current sample
        for redraw_event in (image_selector.change, colormap_dropdown.change):
            redraw_event(
                fn=show_images,
                inputs=view_inputs,
                outputs=plot_outputs
            )

        # Load a random example at startup
        demo.load(
            fn=generate_random_and_show,
            inputs=view_inputs,
            outputs=[user_input, indices_list] + plot_outputs
        )

        # Citation section. The heading and <pre> wrapper keep the BibTeX
        # entry's line breaks, which plain text inside HTML would collapse.
        gr.HTML("""
<h3>Citation</h3>
<pre>
@article{ameen2025simeit,
  title={SimEIT: A Scalable Simulation Framework for Generating Large-Scale Electrical Impedance Tomography Datasets},
  author={Ameen, Ayman A. and Mathis-Ullrich, Franziska and Kainz, Bernhard},
  year={2025},
}
</pre>
""")

    return demo


def main():
    """Main entry point for the application."""
    demo = create_gradio_interface()
    demo.launch(share=True)


if __name__ == "__main__":
    main()