SimEIT-demo / app.py
AymanAmeen's picture
SimEIT Demo
88675f1 verified
"""
SimEIT Dataset Visualizer
A Gradio-based application for visualizing EIT (Electrical Impedance Tomography) datasets
from Hugging Face Hub with interactive plots and configurations.
Author: Ayman A. Ameen
"""
import random
import numpy as np
import h5py
import gradio as gr
import plotly.graph_objects as go
from huggingface_hub import HfFileSystem
# ============================================================================
# CONFIGURATION
# ============================================================================
# Dataset subsets offered in the UI dropdown; DATASET_CONFIG['hf_subset']
# picks the one opened at startup.
AVAILABLE_SUBSETS = ['FourObjects', 'CirclesOnly']

DATASET_CONFIG = {
    'hf_dataset': 'AymanAmeen/SimEIT-dataset',  # Hugging Face dataset repo id
    'hf_split': 'train',
    'hf_subset': AVAILABLE_SUBSETS[0],  # must be one of AVAILABLE_SUBSETS
}

# Image resolution keys read from the HDF5 'image' group.
AVAILABLE_RESOLUTIONS = '256 128_log 64_log 32_log'.split()

# Plotly colorscale names offered in the colormap dropdown.
AVAILABLE_COLORMAPS = (
    'Jet Viridis Plasma Inferno Magma Cividis '
    'Hot Cool RdBu RdYlBu Spectral Turbo '
    'Blues Greens Reds YlOrRd Portland Picnic'
).split()
# ============================================================================
# DATA LOADER
# ============================================================================
class HFDatasetLoader:
    """
    Loads samples from a Hugging Face dataset HDF5 file via streaming.

    Features:
    - Streams data from the Hugging Face Hub through ``HfFileSystem``; the HDF5
      file is opened remotely and read lazily by h5py
    - Keeps a bounded LRU (least recently used) cache of loaded samples
    - Loads only the requested image resolution, or no images at all

    HDF5 layout as indexed below: ``volt/16`` is 2-D and ``image/<resolution>``
    is 3-D, both with the sample index on the last axis.
    """

    # Cache-key marker for voltage-only requests. A bare object() can never
    # collide with a real resolution name.
    _VOLTAGE_ONLY = object()

    def __init__(self, dataset_name, split="train", subset="FourObjects", cache_size=50):
        """
        Initialize the dataset loader and open the remote HDF5 file.

        Args:
            dataset_name: Name of the HuggingFace dataset repository
            split: Dataset split (default: "train"); stored for reference only,
                it is not part of the HDF5 path
            subset: Dataset subset (e.g., "FourObjects", "CirclesOnly")
            cache_size: Maximum number of samples to cache (default: 50);
                a value <= 0 disables caching
        """
        self.dataset_name = dataset_name
        self.split = split
        self.subset = subset
        self.cache_size = cache_size
        # Plain dicts preserve insertion order, so the first key is always the
        # least recently used entry (see _add_to_cache / get_sample).
        self._cache = {}
        print(f"Connecting to dataset {dataset_name} (subset: {subset}) via streaming...")
        # Initialize HuggingFace filesystem for streaming
        self.fs = HfFileSystem()
        self.h5_path = f"datasets/{dataset_name}/{subset}/dataset.h5"
        # Keep the remote file handle open for the loader's lifetime.
        self._file_handle = self.fs.open(self.h5_path, 'rb')
        try:
            self.h5file = h5py.File(self._file_handle, 'r')
            # Number of samples = size of the last (sample) axis.
            self.num_samples = self.h5file['image']['256'].shape[2]
        except Exception:
            # Don't leave the remote handle dangling if the file can't be parsed.
            self.close()
            raise
        print("✓ Dataset connected successfully!")
        print(f" Total samples: {self.num_samples:,}")
        print(f" Cache enabled: storing last {cache_size} samples")

    def close(self):
        """Close the HDF5 file and the underlying remote handle. Safe to call repeatedly."""
        h5file = getattr(self, 'h5file', None)
        handle = getattr(self, '_file_handle', None)
        self.h5file = None
        self._file_handle = None
        try:
            if h5file is not None:
                h5file.close()
        finally:
            # Always release the remote handle, even if h5py fails to close.
            if handle is not None:
                handle.close()

    def __del__(self):
        """Best-effort cleanup of file handles when the loader is garbage-collected."""
        try:
            self.close()
        except Exception:
            # Errors raised in __del__ would only surface as "Exception ignored"
            # warnings (often at interpreter shutdown); nothing useful to do here.
            pass

    def _clamp_index(self, index):
        """Coerce *index* to int and clamp it into [0, num_samples - 1], warning when changed."""
        index = int(index)
        clamped = min(max(index, 0), self.num_samples - 1)
        if clamped != index:
            print(f"⚠ Index {index} out of range [0, {self.num_samples}), using sample {clamped}")
        return clamped

    def get_sample(self, index, image_resolution=None, *, include_images=True):
        """
        Get a specific sample by index from the HDF5 file.

        Args:
            index: Sample index to load (0 to num_samples-1). Out-of-range values
                are clamped to the nearest valid index (with a printed warning).
            image_resolution: Specific resolution to load (e.g., '256', '128_log').
                If None (and include_images is True), loads all resolutions (slower).
            include_images: Keyword-only. If False, load only the voltage data and
                ignore image_resolution.

        Returns:
            dict: ``'volt_16'`` plus one ``'image_<res>'`` entry per loaded
            resolution. The dict is shared with the cache; do not mutate it.
        """
        index = self._clamp_index(index)
        if include_images:
            cache_key = (index, image_resolution)
            # A falsy resolution means "every resolution" (backward compatibility).
            resolutions = (image_resolution,) if image_resolution else tuple(AVAILABLE_RESOLUTIONS)
            label = image_resolution
        else:
            cache_key = (index, self._VOLTAGE_ONLY)
            resolutions = ()
            label = "voltage only"

        if cache_key in self._cache:
            print(f"✓ Cache hit for sample {index}, resolution {label}")
            # Re-insert to mark the entry as most recently used (true LRU).
            sample = self._cache.pop(cache_key)
            self._cache[cache_key] = sample
            return sample

        print(f"Loading sample {index}, resolution {label}...")
        # Voltage is always included; sample index is on the last axis.
        sample = {'volt_16': self.h5file['volt']['16'][:, index]}
        for res in resolutions:
            sample[f'image_{res}'] = self.h5file['image'][res][:, :, index]
        self._add_to_cache(cache_key, sample)
        return sample

    def _add_to_cache(self, key, value):
        """
        Insert (or refresh) *key* as the most recently used cache entry.

        Evicts least recently used entries once more than ``cache_size`` are
        held. Does nothing when ``cache_size <= 0`` (caching disabled).

        Args:
            key: Cache key (tuple of index and resolution marker)
            value: Sample data to cache
        """
        if self.cache_size <= 0:
            return
        self._cache.pop(key, None)
        self._cache[key] = value
        while len(self._cache) > self.cache_size:
            # The first key in insertion order is the least recently used one.
            del self._cache[next(iter(self._cache))]
# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================


def _zscore(values):
    """Return *values* centred on zero mean and, if their std is non-zero, scaled to unit std."""
    centered = values - np.mean(values)
    std = np.std(values)
    # Avoid division by zero for constant inputs.
    return centered / std if std > 0 else centered


def create_heatmap_plot(key, index=0, colorscale='Jet'):
    """
    Create a Plotly heatmap from dataset image.

    Args:
        key: Image resolution key (e.g., '256', '128_log')
        index: Sample index
        colorscale: Plotly colorscale name

    Returns:
        plotly.graph_objects.Figure: Heatmap figure (empty figure on error)
    """
    try:
        # Lazy load: only fetch the specific resolution needed
        sample = _hf_loader.get_sample(index, image_resolution=key)
        img = sample.get(f'image_{key}')
        if img is None:
            print(f"✗ Missing image_{key} in sample {index}")
            return go.Figure()
        img = np.asarray(img)
        # Handle log-scaled images (negative values) by converting back to linear.
        # NOTE(review): log scaling is inferred from negative values only; a
        # log image with no negative entries is shown without np.exp.
        if img.ndim == 2 and np.min(img) < 0:
            img = np.exp(img)
        # If RGB image, convert to grayscale for heatmap
        if img.ndim == 3 and img.shape[-1] == 3:
            img = np.mean(img, axis=2)
        # Per-sample normalization using this image's mean and std
        img_normalized = _zscore(img)
        fig = go.Figure(data=go.Heatmap(
            z=img_normalized,
            colorscale=colorscale,
            showscale=True,
            colorbar=dict(title="Normalized Conductivity")
        ))
        fig.update_layout(
            title=dict(text=f"{key} Image (Normalized) - Sample {index}", x=0.5, xanchor='center'),
            width=450,
            height=450,
            xaxis=dict(showticklabels=False, showgrid=False),
            yaxis=dict(showticklabels=False, showgrid=False, scaleanchor="x", scaleratio=1),
            margin=dict(l=20, r=20, t=50, b=20),
            autosize=False
        )
        return fig
    except Exception as e:
        print(f"✗ Error creating heatmap for image_{key}: {e}")
        import traceback
        traceback.print_exc()
        return go.Figure()


def draw_voltage_plot(index=0):
    """
    Draw voltage plot from dataset.

    Args:
        index: Sample index

    Returns:
        plotly.graph_objects.Figure: Voltage plot figure (empty figure on error)
    """
    try:
        # Voltage-only request: avoids streaming every image resolution, which
        # is what image_resolution=None would do.
        sample = _hf_loader.get_sample(index, include_images=False)
        volt_data = sample.get('volt_16')
        if volt_data is None:
            print(f"✗ Missing voltage data in sample {index}")
            return go.Figure()
        volt_data = np.asarray(volt_data, dtype=np.float64).ravel()
        # Per-sample normalization using this sample's mean and std
        volt_normalized = _zscore(volt_data)
        electrodes = np.arange(1, volt_normalized.size + 1)
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=electrodes,
            y=volt_normalized,
            mode='lines+markers',
            marker=dict(size=6, color='royalblue'),
            line=dict(width=2, color='royalblue')
        ))
        fig.update_layout(
            title=dict(text=f"Voltage Measurement (Normalized) - Sample {index}", x=0.5, xanchor='center'),
            xaxis_title="Electrode Number (n)",
            yaxis_title="Normalized Voltage (a.u.)",
            template="plotly_white",
            showlegend=False,
            width=450,
            height=450,
            margin=dict(l=60, r=20, t=50, b=50),
            autosize=False
        )
        fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
        fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
        return fig
    except Exception as e:
        print(f"✗ Error plotting voltage: {e}")
        import traceback
        traceback.print_exc()
        return go.Figure()
# ============================================================================
# GRADIO UI HELPER FUNCTIONS
# ============================================================================
def get_dataset_info():
    """Return a one-line description of the dataset repo, active subset, and split."""
    repo = DATASET_CONFIG['hf_dataset']
    split = DATASET_CONFIG['hf_split']
    return f"HuggingFace: {repo} (subset: {_hf_loader.subset}, split: {split})"
def get_max_index():
    """Return the largest valid sample index for the currently loaded subset."""
    last_valid = _hf_loader.num_samples - 1
    return last_valid
def generate_random_index(state):
    """
    Pick a uniformly random valid sample index and append it to the history.

    Args:
        state: List of previously selected indices (not modified)

    Returns:
        tuple: (random_index, new_state) where new_state is state + [random_index]
    """
    upper = get_max_index()
    picked = random.randint(0, upper)
    history = state + [picked]
    return picked, history
def select_index(n, state):
    """
    Validate a user-entered index and, if valid, append it to the history.

    Args:
        n: Index entered by the user (may be None or "")
        state: List of previously selected indices (not modified)

    Returns:
        tuple: ("", state) for empty input; (error_message, state) when n is
        outside [0, max index]; otherwise (int(n), state + [int(n)]).
    """
    if n is None or n == "":
        return "", state
    upper = get_max_index()
    if not 0 <= n <= upper:
        return f"Number must be between 0 and {upper}.", state
    chosen = int(n)
    return chosen, state + [chosen]
def show_images(state, image_res, colorscale):
    """
    Render the heatmap and voltage plot for the most recently selected index.

    Args:
        state: List of selected indices; only the last one is displayed
        image_res: Image resolution key to display
        colorscale: Plotly colorscale for the heatmap

    Returns:
        tuple: (image_plot, voltage_plot, status_message); two empty figures
        and "No index selected" when state is empty.
    """
    if not state:
        return go.Figure(), go.Figure(), "No index selected"
    current = state[-1]
    heatmap = create_heatmap_plot(image_res, current, colorscale)
    voltage = draw_voltage_plot(current)
    message = f"✓ Loaded sample {current} with {image_res} resolution and colormap: {colorscale}"
    return heatmap, voltage, message
def generate_random_and_show(state, image_res, colorscale):
    """Pick a random index and return (index, new_state, image_plot, voltage_plot, status)."""
    picked, history = generate_random_index(state)
    heatmap, voltage, message = show_images(history, image_res, colorscale)
    return picked, history, heatmap, voltage, message
def select_n_and_show(n, state, image_res, colorscale):
    """Validate index n and return (new_state, image_plot, voltage_plot, status)."""
    history = select_index(n, state)[1]
    return (history, *show_images(history, image_res, colorscale))
def reload_dataset(subset, state, image_res, colorscale):
    """
    Reload the dataset with a new subset and display a sample.

    The new loader is built *before* the current one is replaced. A failed
    connection therefore leaves the previously loaded subset fully usable,
    instead of leaving the module-level loader undefined.

    Args:
        subset: New subset to load
        state: Current state list of selected indices
        image_res: Image resolution
        colorscale: Colorscale for heatmap

    Returns:
        tuple: (header markdown, index-input update, new state, image plot,
        voltage plot, status message). On failure the header, input, and
        plots are left unchanged and state is returned as-is.
    """
    global _hf_loader
    try:
        new_loader = HFDatasetLoader(
            DATASET_CONFIG['hf_dataset'], DATASET_CONFIG['hf_split'], subset
        )
    except Exception as e:
        # Keep the old loader, the selection history, and the current plots.
        return (
            gr.update(),
            gr.update(),
            state,
            gr.update(),
            gr.update(),
            f"✗ Error loading subset {subset}: {str(e)}"
        )
    # Rebinding releases the previous loader; its __del__ closes its file
    # handles once it is garbage-collected.
    _hf_loader = new_loader
    max_idx = get_max_index()
    # Update dataset info
    info_md = f"""
# SimEIT: Dataset Visualizer
**Dataset:** `{get_dataset_info()}` | **Total Samples:** {max_idx + 1:,}
"""
    # Keep the previously viewed sample if it exists in the new subset.
    if state and 0 <= state[-1] <= max_idx:
        sample_index = state[-1]
    else:
        sample_index = random.randint(0, max_idx)
    # Restart the history with the sample being displayed
    new_state = [sample_index]
    image_plot = create_heatmap_plot(image_res, sample_index, colorscale)
    volt_plot = draw_voltage_plot(sample_index)
    status_msg = f"✓ Loaded subset: {subset} ({max_idx + 1:,} samples) - Displaying sample {sample_index}"
    return (
        info_md,
        gr.Number(label=f"Enter an integer (0–{max_idx})", precision=0, value=sample_index),
        new_state,
        image_plot,
        volt_plot,
        status_msg
    )
# ============================================================================
# MAIN APPLICATION
# ============================================================================
# Global dataset loader instance (None until create_gradio_interface() runs).
# create_gradio_interface() and reload_dataset() assign it via `global`; the
# plotting and UI helper functions read it.
_hf_loader: "HFDatasetLoader | None" = None


def create_gradio_interface():
    """
    Create and configure the Gradio interface.

    Connects to the default subset (DATASET_CONFIG['hf_subset']) and stores the
    loader in the module-level ``_hf_loader`` *before* building any component.
    This ordering matters because the header text and the index input label
    read the sample count from the loader.

    Returns:
        gr.Blocks: Configured Gradio application (not launched here)
    """
    global _hf_loader
    # Initialize configuration
    dataset_name = DATASET_CONFIG['hf_dataset']
    split = DATASET_CONFIG['hf_split']
    default_subset = DATASET_CONFIG['hf_subset']
    # Initialize dataset loader with default subset (opens the remote file)
    _hf_loader = HFDatasetLoader(dataset_name, split, default_subset)
    with gr.Blocks(title="SimEIT Dataset Visualizer") as demo:
        # Header; reload_dataset() returns replacement text for it
        dataset_info_display = gr.Markdown(f"""
# SimEIT: Dataset Visualizer
**Dataset:** `{get_dataset_info()}` | **Total Samples:** {get_max_index() + 1:,}
""")
        # Controls section
        gr.Markdown("### Choose dataset subset, sample index, image resolution, and colormap")
        with gr.Row():
            with gr.Column():
                subset_selector = gr.Dropdown(
                    choices=AVAILABLE_SUBSETS,
                    value=default_subset,
                    label="Select Dataset Subset"
                )
                # The range in this label is computed once here;
                # reload_dataset() returns a gr.Number with an updated label.
                user_input = gr.Number(
                    label=f"Enter an integer (0–{get_max_index()})",
                    precision=0
                )
                btn_select_n = gr.Button("Confirm Number")
                btn_random = gr.Button("Generate Random Number")
            with gr.Column():
                image_selector = gr.Dropdown(
                    choices=AVAILABLE_RESOLUTIONS,
                    value='256',
                    label="Select Image Resolution"
                )
                colormap_dropdown = gr.Dropdown(
                    choices=AVAILABLE_COLORMAPS,
                    value='Jet',
                    label="Select Colormap"
                )
        # State for tracking indices: history of selected sample indices;
        # show_images() displays the last entry.
        indices_list = gr.State(value=[])
        # Visualization plots
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                image_plot = gr.Plot(label="Image Heatmap")
            with gr.Column(scale=2):
                volt_plot = gr.Plot(label="Voltage Plot")
        # Status output
        status_output = gr.Textbox(label="Status", interactive=False)
        # Event handlers
        # Switching subsets rebuilds the loader and refreshes header, input, and plots.
        subset_selector.change(
            fn=reload_dataset,
            inputs=[subset_selector, indices_list, image_selector, colormap_dropdown],
            outputs=[dataset_info_display, user_input, indices_list, image_plot, volt_plot, status_output]
        )
        btn_random.click(
            fn=generate_random_and_show,
            inputs=[indices_list, image_selector, colormap_dropdown],
            outputs=[user_input, indices_list, image_plot, volt_plot, status_output]
        )
        btn_select_n.click(
            fn=select_n_and_show,
            inputs=[user_input, indices_list, image_selector, colormap_dropdown],
            outputs=[indices_list, image_plot, volt_plot, status_output]
        )
        # Allow Enter key to confirm the number
        user_input.submit(
            fn=select_n_and_show,
            inputs=[user_input, indices_list, image_selector, colormap_dropdown],
            outputs=[indices_list, image_plot, volt_plot, status_output]
        )
        # Changing resolution or colormap re-renders the current (last) index.
        image_selector.change(
            fn=show_images,
            inputs=[indices_list, image_selector, colormap_dropdown],
            outputs=[image_plot, volt_plot, status_output]
        )
        colormap_dropdown.change(
            fn=show_images,
            inputs=[indices_list, image_selector, colormap_dropdown],
            outputs=[image_plot, volt_plot, status_output]
        )
        # Load a random example at startup
        demo.load(
            fn=generate_random_and_show,
            inputs=[indices_list, image_selector, colormap_dropdown],
            outputs=[user_input, indices_list, image_plot, volt_plot, status_output]
        )
        # Citation section
        gr.HTML("""
<!--Citation -->
<section class="section" id="Citation">
<div class="container is-max-desktop content">
<h2 class="title">Citation</h2>
<pre><code>@article{ameen2025simeit,
title={SimEIT: A Scalable Simulation Framework for Generating Large-Scale Electrical Impedance Tomography Datasets},
author={Ameen, Ayman A. and Mathis-Ullrich, Franziska and Kainz, Bernhard},
year={2025},
}</code></pre>
</div>
</section>
""")
    return demo
def main():
    """Build the Gradio UI and serve it; share=True requests a public Gradio share link."""
    create_gradio_interface().launch(share=True)


if __name__ == "__main__":
    main()