Spaces:
Sleeping
Sleeping
| """ | |
| SimEIT Dataset Visualizer | |
| A Gradio-based application for visualizing EIT (Electrical Impedance Tomography) datasets | |
| from Hugging Face Hub with interactive plots and configurations. | |
| Author: Ayman A. Ameen | |
| """ | |
| import random | |
| import numpy as np | |
| import h5py | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from huggingface_hub import HfFileSystem | |
# ============================================================================
# CONFIGURATION
# ============================================================================
# Default Hugging Face source opened when the app starts.
DATASET_CONFIG = {
    'hf_dataset': 'AymanAmeen/SimEIT-dataset',
    'hf_split': 'train',
    'hf_subset': 'FourObjects',  # Options: 'FourObjects' or 'CirclesOnly'
}

# Subset directories that can be selected in the UI.
AVAILABLE_SUBSETS = ['FourObjects', 'CirclesOnly']

# Keys of the image datasets stored under the HDF5 'image' group.
AVAILABLE_RESOLUTIONS = ['256', '128_log', '64_log', '32_log']

# Colorscale names passed straight to plotly's go.Heatmap(colorscale=...).
# Every entry must be one of Plotly's named colorscales; an unknown name
# (e.g. matplotlib's 'Cool') makes Plotly raise, leaving an empty plot.
AVAILABLE_COLORMAPS = [
    'Jet', 'Viridis', 'Plasma', 'Inferno', 'Magma', 'Cividis',
    'Hot', 'RdBu', 'RdYlBu', 'Spectral', 'Turbo',
    'Blues', 'Greens', 'Reds', 'YlOrRd', 'Portland', 'Picnic',
]
# ============================================================================
# DATA LOADER
# ============================================================================
class HFDatasetLoader:
    """
    Loads samples from a Hugging Face dataset HDF5 file via streaming.

    Features:
    - Reads the HDF5 file through ``HfFileSystem`` instead of downloading it first
    - Keeps a least-recently-used (LRU) cache of loaded samples
    - Can load a single image resolution instead of all of them
    """

    def __init__(self, dataset_name, split="train", subset="FourObjects", cache_size=50):
        """
        Open the subset's remote HDF5 file and read the dataset size.

        Args:
            dataset_name: Hugging Face dataset repository id
            split: Split label; stored on the instance but not part of the file path
            subset: Subset directory inside the repository (e.g., "FourObjects", "CirclesOnly")
            cache_size: Maximum number of cached samples (default: 50); a value
                <= 0 disables caching
        """
        self.dataset_name = dataset_name
        self.split = split
        self.subset = subset
        self.cache_size = cache_size
        # Plain dicts preserve insertion order, so the first key is always the
        # least recently used entry and the last key the most recently used.
        self._cache = {}
        print(f"Connecting to dataset {dataset_name} (subset: {subset}) via streaming...")
        self.fs = HfFileSystem()
        self.h5_path = f"datasets/{dataset_name}/{subset}/dataset.h5"
        # The handle stays open for the loader's lifetime because h5py reads
        # through it whenever get_sample() slices a dataset.
        self._file_handle = self.fs.open(self.h5_path, 'rb')
        self.h5file = h5py.File(self._file_handle, 'r')
        # Samples are stacked along axis 2 of each image dataset
        self.num_samples = self.h5file['image']['256'].shape[2]
        print("✓ Dataset connected successfully!")
        print(f" Total samples: {self.num_samples:,}")
        print(f" Cache enabled: storing last {cache_size} samples")

    def close(self):
        """Close the HDF5 file and the remote file handle; safe to call more than once."""
        h5file = getattr(self, 'h5file', None)
        if h5file is not None:
            self.h5file = None
            h5file.close()
        handle = getattr(self, '_file_handle', None)
        if handle is not None:
            self._file_handle = None
            handle.close()

    def __del__(self):
        """Release file handles when the loader is garbage-collected."""
        self.close()

    def get_sample(self, index, image_resolution=None):
        """
        Get a sample by index, serving it from the cache when possible.

        Args:
            index: Sample index (0 to num_samples-1). Any index outside that
                range (negative values included) is replaced by the last valid
                index and a warning is printed; no exception is raised.
            image_resolution: Resolution key to load (e.g., '256', '128_log').
                If None, every resolution in AVAILABLE_RESOLUTIONS is loaded (slower).

        Returns:
            dict: ``'volt_16'`` holds the sample's voltage data and
            ``'image_<res>'`` the image slice for each loaded resolution.
            The dict is the cached object itself, so callers must not mutate it.
        """
        if index < 0 or index >= self.num_samples:
            print(f"⚠ Index {index} out of range [0, {self.num_samples}), using last sample {self.num_samples - 1}")
            index = self.num_samples - 1

        cache_key = (index, image_resolution)
        if cache_key in self._cache:
            print(f"✓ Cache hit for sample {index}, resolution {image_resolution}")
            # Re-insert to mark the entry as most recently used (true LRU)
            sample = self._cache.pop(cache_key)
            self._cache[cache_key] = sample
            return sample

        print(f"Loading sample {index}, resolution {image_resolution}...")
        sample = {}
        # Voltage data: one column per sample
        sample['volt_16'] = self.h5file['volt']['16'][:, index]
        if image_resolution:
            # Lazy load: only the requested image resolution
            sample[f'image_{image_resolution}'] = self.h5file['image'][image_resolution][:, :, index]
        else:
            # Load all resolutions (backward compatibility)
            for res in AVAILABLE_RESOLUTIONS:
                sample[f'image_{res}'] = self.h5file['image'][res][:, :, index]

        self._add_to_cache(cache_key, sample)
        return sample

    def _add_to_cache(self, key, value):
        """
        Store ``value`` under ``key`` as the most recently used entry.

        Evicts least-recently-used entries while the cache is full. Does
        nothing when ``cache_size`` is <= 0 (caching disabled).

        Args:
            key: Cache key (tuple of index and resolution)
            value: Sample data to cache
        """
        if self.cache_size <= 0:
            return
        # Removing first makes the re-insert below move an existing key to the end
        self._cache.pop(key, None)
        while len(self._cache) >= self.cache_size:
            # The first key in insertion order is the least recently used
            del self._cache[next(iter(self._cache))]
        self._cache[key] = value
# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================
def create_heatmap_plot(key, index=0, colorscale='Jet'):
    """
    Render one sample's image as a z-score normalized Plotly heatmap.

    Args:
        key: Image resolution key (e.g., '256', '128_log')
        index: Sample index requested from the global loader
        colorscale: Plotly colorscale name

    Returns:
        plotly.graph_objects.Figure: The heatmap, or an empty figure when the
        image is missing or any step fails (the error is printed).
    """
    try:
        image_field = f'image_{key}'
        # Fetch only the requested resolution for this index
        record = _hf_loader.get_sample(index, image_resolution=key)
        raw = record.get(image_field)
        if raw is None:
            print(f"✗ Missing {image_field} in sample {index}")
            return go.Figure()

        data = np.array(raw)

        # Heuristic: a 2-D image containing negative values is assumed to be
        # log-scaled and is mapped back to linear scale with exp().
        if data.ndim == 2 and np.min(data) < 0:
            data = np.exp(data)

        # Collapse a trailing 3-channel axis into a single intensity plane
        if data.ndim == 3 and data.shape[-1] == 3:
            data = np.mean(data, axis=2)

        # Per-sample z-score; skip the division when the image is constant
        centered = data - np.mean(data)
        spread = np.std(data)
        z_values = centered / spread if spread > 0 else centered

        heatmap = go.Heatmap(
            z=z_values,
            colorscale=colorscale,
            showscale=True,
            colorbar=dict(title="Normalized Conductivity"),
        )
        figure = go.Figure(data=heatmap)
        figure.update_layout(
            title=dict(text=f"{key} Image (Normalized) - Sample {index}", x=0.5, xanchor='center'),
            width=450,
            height=450,
            xaxis=dict(showticklabels=False, showgrid=False),
            yaxis=dict(showticklabels=False, showgrid=False, scaleanchor="x", scaleratio=1),
            margin=dict(l=20, r=20, t=50, b=20),
            autosize=False,
        )
        return figure
    except Exception as exc:
        print(f"✗ Error creating heatmap for image_{key}: {exc}")
        import traceback
        traceback.print_exc()
        return go.Figure()
def draw_voltage_plot(index=0):
    """
    Draw one sample's voltage measurements as a z-score normalized line plot.

    Args:
        index: Sample index requested from the global loader

    Returns:
        plotly.graph_objects.Figure: The line plot, or an empty figure when the
        voltage data is missing or any step fails (the error is printed).
    """
    try:
        # Every sample returned by get_sample() carries 'volt_16' whichever
        # image resolution is requested, whereas image_resolution=None would
        # stream *all* image resolutions just to read the voltages. Ask for
        # the '32_log' image only (the lowest resolution, judging by its key).
        sample = _hf_loader.get_sample(index, image_resolution='32_log')
        volt_data = sample.get('volt_16')
        if volt_data is None:
            print(f"✗ Missing voltage data in sample {index}")
            return go.Figure()

        # Flatten so the plot always receives a 1-D series
        volt_data = np.asarray(volt_data, dtype=np.float64).ravel()

        # Per-sample z-score; skip the division when the signal is constant
        volt_mean = np.mean(volt_data)
        volt_std = np.std(volt_data)
        if volt_std > 0:
            volt_normalized = (volt_data - volt_mean) / volt_std
        else:
            volt_normalized = volt_data - volt_mean

        electrodes = np.arange(1, len(volt_normalized) + 1)

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=electrodes,
            y=volt_normalized,
            mode='lines+markers',
            marker=dict(size=6, color='royalblue'),
            line=dict(width=2, color='royalblue'),
        ))
        fig.update_layout(
            title=dict(text=f"Voltage Measurement (Normalized) - Sample {index}", x=0.5, xanchor='center'),
            xaxis_title="Electrode Number (n)",
            yaxis_title="Normalized Voltage (a.u.)",
            template="plotly_white",
            showlegend=False,
            width=450,
            height=450,
            margin=dict(l=60, r=20, t=50, b=50),
            autosize=False,
        )
        fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
        fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
        return fig
    except Exception as e:
        print(f"✗ Error plotting voltage: {e}")
        # Same diagnostics as the heatmap path, so failures are debuggable
        import traceback
        traceback.print_exc()
        return go.Figure()
# ============================================================================
# GRADIO UI HELPER FUNCTIONS
# ============================================================================
def get_dataset_info():
    """Describe the active source: configured repository and split, plus the loader's subset."""
    repo = DATASET_CONFIG['hf_dataset']
    active_subset = _hf_loader.subset
    split_name = DATASET_CONFIG['hf_split']
    return f"HuggingFace: {repo} (subset: {active_subset}, split: {split_name})"
def get_max_index():
    """Return the highest valid sample index of the active loader (num_samples - 1)."""
    sample_count = _hf_loader.num_samples
    return sample_count - 1
def generate_random_index(state):
    """
    Pick a uniformly random valid sample index and append it to the history.

    Args:
        state: List of previously selected indices (left unmodified)

    Returns:
        tuple: (picked_index, new history list ending with picked_index)
    """
    picked = random.randint(0, get_max_index())
    history = state + [picked]
    return picked, history
def select_index(n, state):
    """
    Validate a user-entered index and, if valid, append it to the history.

    Args:
        n: Requested index (None or "" means no input)
        state: List of previously selected indices (left unmodified)

    Returns:
        tuple: (int index, extended history) on success;
        ("", state) for empty input;
        (error message string, state) when n is out of range.
    """
    # Nothing entered: report nothing and keep the history as-is
    if n is None or n == "":
        return "", state

    upper = get_max_index()
    if not 0 <= n <= upper:
        return f"Number must be between 0 and {upper}.", state

    chosen = int(n)
    return chosen, state + [chosen]
def show_images(state, image_res, colorscale):
    """
    Render plots for the most recently selected index in ``state``.

    Args:
        state: List of selected indices; the last entry is displayed
        image_res: Image resolution key for the heatmap
        colorscale: Plotly colorscale for the heatmap

    Returns:
        tuple: (image_plot, voltage_plot, status_message); two empty figures
        and "No index selected" when the history is empty.
    """
    if not state:
        return go.Figure(), go.Figure(), "No index selected"

    current = state[-1]
    heatmap = create_heatmap_plot(image_res, current, colorscale)
    voltage = draw_voltage_plot(current)
    message = f"✓ Loaded sample {current} with {image_res} resolution and colormap: {colorscale}"
    return heatmap, voltage, message
def generate_random_and_show(state, image_res, colorscale):
    """
    Pick a random index, record it, and render its plots.

    Returns:
        tuple: (picked_index, new_state, image_plot, voltage_plot, status_message)
    """
    picked, history = generate_random_index(state)
    image_fig, volt_fig, status = show_images(history, image_res, colorscale)
    return picked, history, image_fig, volt_fig, status
def select_n_and_show(n, state, image_res, colorscale):
    """
    Validate a user-entered index, record it, and render its plots.

    Args:
        n: Requested index from the number input
        state: Current list of selected indices
        image_res: Image resolution key for the heatmap
        colorscale: Plotly colorscale for the heatmap

    Returns:
        tuple: (updated_state, image_plot, voltage_plot, status_message).
        When n is rejected, the state is unchanged, the previously selected
        sample (if any) is re-rendered, and the status shows why n was rejected.
    """
    selected, new_list = select_index(n, state)
    image_fig, volt_fig, status = show_images(new_list, image_res, colorscale)
    # select_index signals a rejected value with a non-empty message string.
    # Surface it instead of a misleading "✓ Loaded sample ..." for the old sample.
    if isinstance(selected, str) and selected:
        status = f"✗ {selected}"
    return new_list, image_fig, volt_fig, status
def reload_dataset(subset, state, image_res, colorscale):
    """
    Switch the global loader to another subset and display one of its samples.

    The new loader is fully opened before it replaces the current one, so if
    opening fails the previously loaded subset stays active and usable.

    Args:
        subset: Subset to load (e.g. one of AVAILABLE_SUBSETS)
        state: Current list of selected indices
        image_res: Image resolution key for the heatmap
        colorscale: Plotly colorscale for the heatmap

    Returns:
        tuple: (header markdown, index-input update, new state, image plot,
        voltage plot, status message). On failure the header, index input and
        plots are left unchanged, the state is returned as-is, and the status
        message reports the error.
    """
    global _hf_loader
    try:
        new_loader = HFDatasetLoader(
            DATASET_CONFIG['hf_dataset'],
            DATASET_CONFIG['hf_split'],
            subset,
        )
        # Swap only after the new loader opened successfully. Dropping our
        # reference to the old loader lets it be garbage-collected (its
        # __del__ closes the file handles) once nothing else uses it.
        old_loader, _hf_loader = _hf_loader, new_loader
        del old_loader

        max_idx = get_max_index()
        info_md = (
            "\n# SimEIT: Dataset Visualizer\n"
            f"**Dataset:** `{get_dataset_info()}` | **Total Samples:** {max_idx + 1:,}\n"
        )

        # Keep the previously viewed index if it also exists in the new subset
        if state and 0 <= state[-1] <= max_idx:
            sample_index = state[-1]
        else:
            sample_index = random.randint(0, max_idx)
        new_state = [sample_index]

        image_plot = create_heatmap_plot(image_res, sample_index, colorscale)
        volt_plot = draw_voltage_plot(sample_index)
        status_msg = f"✓ Loaded subset: {subset} ({max_idx + 1:,} samples) - Displaying sample {sample_index}"
        return (
            info_md,
            gr.Number(label=f"Enter an integer (0–{max_idx})", precision=0, value=sample_index),
            new_state,
            image_plot,
            volt_plot,
            status_msg,
        )
    except Exception as e:
        # gr.update() with no arguments leaves the component untouched
        return (
            gr.update(),
            gr.update(),
            state,
            gr.update(),
            gr.update(),
            f"✗ Error loading subset {subset}: {str(e)}",
        )
# ============================================================================
# MAIN APPLICATION
# ============================================================================
# Shared dataset loader. It starts empty; create_gradio_interface() installs
# the initial HFDatasetLoader and reload_dataset() replaces it on subset change.
_hf_loader: "HFDatasetLoader | None" = None
def create_gradio_interface():
    """
    Build the Gradio Blocks UI after opening the default dataset subset.

    Side effect: replaces the module-level ``_hf_loader`` with a loader for
    ``DATASET_CONFIG['hf_subset']``.

    Returns:
        gr.Blocks: The configured (not yet launched) application.
    """
    global _hf_loader
    repo = DATASET_CONFIG['hf_dataset']
    split_name = DATASET_CONFIG['hf_split']
    default_subset = DATASET_CONFIG['hf_subset']
    _hf_loader = HFDatasetLoader(repo, split_name, default_subset)

    with gr.Blocks(title="SimEIT Dataset Visualizer") as demo:
        # Header (rebuilt by reload_dataset when the subset changes)
        header_md = (
            "\n# SimEIT: Dataset Visualizer\n"
            f"**Dataset:** `{get_dataset_info()}` | **Total Samples:** {get_max_index() + 1:,}\n"
        )
        dataset_info_display = gr.Markdown(header_md)

        # Controls: data selection on the left, rendering options on the right
        gr.Markdown("### Choose dataset subset, sample index, image resolution, and colormap")
        with gr.Row():
            with gr.Column():
                subset_selector = gr.Dropdown(
                    choices=AVAILABLE_SUBSETS, value=default_subset, label="Select Dataset Subset"
                )
                user_input = gr.Number(label=f"Enter an integer (0–{get_max_index()})", precision=0)
                btn_select_n = gr.Button("Confirm Number")
                btn_random = gr.Button("Generate Random Number")
            with gr.Column():
                image_selector = gr.Dropdown(
                    choices=AVAILABLE_RESOLUTIONS, value='256', label="Select Image Resolution"
                )
                colormap_dropdown = gr.Dropdown(
                    choices=AVAILABLE_COLORMAPS, value='Jet', label="Select Colormap"
                )

        # Per-session history of selected indices; the last entry is displayed
        indices_list = gr.State(value=[])

        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                image_plot = gr.Plot(label="Image Heatmap")
            with gr.Column(scale=2):
                volt_plot = gr.Plot(label="Voltage Plot")

        status_output = gr.Textbox(label="Status", interactive=False)

        # Inputs/outputs shared by most handlers
        view_controls = [image_selector, colormap_dropdown]
        plot_outputs = [image_plot, volt_plot, status_output]

        subset_selector.change(
            fn=reload_dataset,
            inputs=[subset_selector, indices_list, *view_controls],
            outputs=[dataset_info_display, user_input, indices_list, *plot_outputs],
        )
        btn_random.click(
            fn=generate_random_and_show,
            inputs=[indices_list, *view_controls],
            outputs=[user_input, indices_list, *plot_outputs],
        )
        # The button and pressing Enter in the number box both confirm the index
        for confirm_event in (btn_select_n.click, user_input.submit):
            confirm_event(
                fn=select_n_and_show,
                inputs=[user_input, indices_list, *view_controls],
                outputs=[indices_list, *plot_outputs],
            )
        # Changing resolution or colormap re-renders the current sample
        for view_control in view_controls:
            view_control.change(
                fn=show_images,
                inputs=[indices_list, *view_controls],
                outputs=[*plot_outputs],
            )
        # Show a random sample as soon as the page loads
        demo.load(
            fn=generate_random_and_show,
            inputs=[indices_list, *view_controls],
            outputs=[user_input, indices_list, *plot_outputs],
        )

        citation_html = """
<!--Citation -->
<section class="section" id="Citation">
<div class="container is-max-desktop content">
<h2 class="title">Citation</h2>
<pre><code>@article{ameen2025simeit,
title={SimEIT: A Scalable Simulation Framework for Generating Large-Scale Electrical Impedance Tomography Datasets},
author={Ameen, Ayman A. and Mathis-Ullrich, Franziska and Kainz, Bernhard},
year={2025},
}</code></pre>
</div>
</section>
"""
        gr.HTML(citation_html)

    return demo
def main():
    """Entry point: build the Gradio app and launch it with share=True."""
    app = create_gradio_interface()
    app.launch(share=True)


# Run the app only when executed as a script, not when imported
if __name__ == "__main__":
    main()