# h5ad_viewer — Gradio Space application (commit 05fdb87)
import gradio as gr
import os
import io
import zipfile
import tempfile
import csv
import datetime
from pathlib import Path
from typing import Optional, Tuple, List, Dict
import numpy as np
import plotly.graph_objects as go
from utils.loader import H5adLoader
from utils.validator import AnnDataValidator
from utils.plot import SpatialPlotter, SpatialImageExtractor
from utils.data_source_manager import DataSourceManager
class SpatialViewer:
    """Main application class for spatial transcriptomics viewer.

    Wraps a DataSourceManager (registry of loaded AnnData datasets) and
    exposes loading, searching, single-gene and batch visualization used
    by the Gradio interface.
    """

    # Default demo dataset (filename under ./data) to load on startup
    DEFAULT_DEMO = "Cerebellum-MALDI-MSI.h5ad"

    def __init__(self):
        # Registry of loaded datasets; tracks which one is "current"
        self.data_manager = DataSourceManager()
        # NOTE(review): never read elsewhere in this class — all methods go
        # through self.data_manager.get_current_source(); looks vestigial.
        self.current_source = None
    def load_default_demo(self) -> Tuple[str, Optional[gr.Plot], gr.update, gr.update, str]:
        """
        Load default demo dataset on app startup.

        Returns:
            Tuple of (status, overview_plot, selector_update, row_visibility, dataset_info)
        """
        demo_path = Path("data") / self.DEFAULT_DEMO
        if not demo_path.exists():
            # Demo file missing: keep the UI usable, hide the selector row
            return (
                "Demo dataset not found. Please load data manually.",
                None,
                gr.update(),
                gr.update(visible=False),
                "No dataset loaded"
            )
        try:
            adata = H5adLoader.load_from_source(str(demo_path))
            # Validate data before registering it with the manager
            is_valid, errors = AnnDataValidator.validate(adata)
            if not is_valid:
                return (
                    "Demo dataset validation failed: " + "; ".join(errors),
                    None,
                    gr.update(),
                    gr.update(visible=False),
                    "No dataset loaded"
                )
            # Add to data manager (becomes the current source)
            source_id = self.data_manager.add_source(
                name=self.DEFAULT_DEMO,
                source_type="demo",
                source_path=str(demo_path),
                adata=adata
            )
            # Create overview plot from the spatial coordinates
            spatial_coords = adata.obsm["spatial"]
            overview_fig = SpatialPlotter.create_overview_plot(spatial_coords)
            status = (
                f"βœ… Auto-loaded demo dataset!\n"
                f"- Dataset: {self.DEFAULT_DEMO}\n"
                f"- Observations (spots/cells): {adata.n_obs:,}\n"
                f"- Variables (genes): {adata.n_vars:,}\n"
                f"- Spatial coordinates: {spatial_coords.shape}\n"
                f"\nReady to visualize gene expression. Switch to 'Visualize Gene' tab."
            )
            # Dataset selector update: show the new dataset pre-selected
            choices = self.data_manager.get_source_choices()
            selector_update = gr.update(
                choices=choices,
                value=self.data_manager.current_id,
                visible=True
            )
            # Dataset info for Visualize tab
            current_source = self.data_manager.get_current_source()
            dataset_info = f"πŸ“Š Current: {current_source.name}\n({current_source.n_obs:,} cells, {current_source.n_vars:,} genes)"
            return status, overview_fig, selector_update, gr.update(visible=True), dataset_info
        except Exception as e:
            # Any load/plot failure degrades gracefully to the "no data" state
            return (
                f"Failed to load demo dataset: {str(e)}",
                None,
                gr.update(),
                gr.update(visible=False),
                "No dataset loaded"
            )
    def load_data(
        self, source_type: str, demo_dataset: Optional[str] = None, url: Optional[str] = None, file_path: Optional[str] = None
    ) -> Tuple[str, Optional[gr.Plot], gr.update]:
        """
        Load h5ad data from various sources.

        Now supports ZIP files containing multiple h5ad files.

        Args:
            source_type: Type of source ('demo', 'url', 'upload')
            demo_dataset: Selected demo dataset name (if source_type is 'demo')
            url: URL to h5ad file (if source_type is 'url')
            file_path: Path to uploaded file (if source_type is 'upload')

        Returns:
            Tuple of (status_message, overview_plot, dataset_selector_update)
        """
        try:
            # Resolve the source path/URL and a human-readable display name
            if source_type == "demo":
                if not demo_dataset:
                    return "Please select a demo dataset.", None, gr.update()
                demo_path = Path("data") / demo_dataset
                if not demo_path.exists():
                    return f"Demo dataset not found: {demo_dataset}", None, gr.update()
                source = str(demo_path)
                display_name = demo_dataset
            elif source_type == "url":
                if not url or url.strip() == "":
                    return "Please provide a valid URL.", None, gr.update()
                source = url.strip()
                # Fall back to a generic label when the URL ends with "/"
                display_name = source.split("/")[-1] or "URL Dataset"
            elif source_type == "upload":
                if not file_path:
                    return "Please upload a file.", None, gr.update()
                source = file_path
                display_name = Path(file_path).name
            else:
                return f"Unknown source type: {source_type}", None, gr.update()
            # Load data; a ZIP source yields a list of AnnData objects
            loaded_data = H5adLoader.load_from_source(source)
            # Handle multiple datasets (from ZIP file)
            if isinstance(loaded_data, list):
                # Multiple h5ad files loaded from ZIP
                status_messages = []
                loaded_count = 0
                for idx, adata in enumerate(loaded_data):
                    # Validate each part; invalid parts are skipped, not fatal
                    is_valid, errors = AnnDataValidator.validate(adata)
                    if not is_valid:
                        status_messages.append(
                            f"Dataset {idx + 1} validation failed:\n" + "\n".join(f" - {e}" for e in errors)
                        )
                        continue
                    # Register each valid part under a "Part N" name
                    file_name = f"{display_name} - Part {idx + 1}"
                    source_id = self.data_manager.add_source(
                        name=file_name,
                        source_type=source_type,
                        source_path=source,
                        adata=adata
                    )
                    loaded_count += 1
                if loaded_count == 0:
                    return "No valid datasets found in ZIP file.\n" + "\n".join(status_messages), None, gr.update()
                # Plot the current (latest registered) dataset
                current_source = self.data_manager.get_current_source()
                spatial_coords = current_source.adata.obsm["spatial"]
                overview_fig = SpatialPlotter.create_overview_plot(spatial_coords)
                status = (
                    f"Successfully loaded {loaded_count} dataset(s) from ZIP file!\n\n"
                    f"Current dataset: {current_source.name}\n"
                    f"- Observations (spots/cells): {current_source.n_obs:,}\n"
                    f"- Variables (genes): {current_source.n_vars:,}\n"
                    f"- Spatial coordinates: {spatial_coords.shape}\n"
                    f"\nUse the dataset selector above to switch between datasets.\n"
                    f"Ready to visualize gene expression."
                )
            else:
                # Single h5ad file
                adata = loaded_data
                # Validate before registering
                is_valid, errors = AnnDataValidator.validate(adata)
                if not is_valid:
                    error_msg = "Validation errors:\n" + "\n".join(f"- {e}" for e in errors)
                    return error_msg, None, gr.update()
                # Add to data manager (becomes the current source)
                source_id = self.data_manager.add_source(
                    name=display_name,
                    source_type=source_type,
                    source_path=source,
                    adata=adata
                )
                # Create overview plot
                spatial_coords = adata.obsm["spatial"]
                overview_fig = SpatialPlotter.create_overview_plot(spatial_coords)
                status = (
                    f"Successfully loaded data!\n"
                    f"- Dataset: {display_name}\n"
                    f"- Observations (spots/cells): {adata.n_obs:,}\n"
                    f"- Variables (genes): {adata.n_vars:,}\n"
                    f"- Spatial coordinates: {spatial_coords.shape}\n"
                    f"\nReady to visualize gene expression."
                )
            # Refresh dataset selector with the updated choice list
            choices = self.data_manager.get_source_choices()
            selector_update = gr.update(
                choices=choices,
                value=self.data_manager.current_id,
                visible=True
            )
            return status, overview_fig, selector_update
        except Exception as e:
            return f"Error loading data: {str(e)}", None, gr.update()
def switch_dataset(self, source_id: str) -> Tuple[str, Optional[gr.Plot]]:
"""
Switch to a different loaded dataset
Args:
source_id: ID of the dataset to switch to
Returns:
Tuple of (info_message, overview_plot)
"""
if not source_id:
return "No dataset selected.", None
success = self.data_manager.set_current(source_id)
if not success:
return f"Dataset not found: {source_id}", None
current_source = self.data_manager.get_current_source()
spatial_coords = current_source.adata.obsm["spatial"]
overview_fig = SpatialPlotter.create_overview_plot(spatial_coords)
info = current_source.get_info()
return info, overview_fig
    def visualize_gene(
        self,
        gene_name: str,
        point_size: int = 5,
        use_log: bool = True,
        colorscale: str = "Viridis",
        show_background: bool = False,
        background_opacity: float = 0.5,
    ) -> Tuple[str, Optional[gr.Plot], str, str]:
        """
        Visualize gene expression in spatial context.

        Args:
            gene_name: Feature/gene name to plot (whitespace is stripped).
            point_size: Marker size for the scatter points.
            use_log: Apply log1p transform before plotting.
            colorscale: Plotly colorscale name.
            show_background: Overlay the tissue image from the h5ad, if present.
            background_opacity: Opacity of the background image overlay.

        Returns:
            Tuple of (status_message, plot_figure, stats_text, dataset_info).
        """
        current_source = self.data_manager.get_current_source()
        if current_source is None:
            return "❌ Please load data first.", None, "", ""
        if current_source.adata is None:
            # Registered (e.g. after memory was freed) but not loaded yet
            return "❌ Dataset registered but not yet loaded. Please select it in the 'Select Dataset' tab first.", None, "", ""
        if not gene_name or gene_name.strip() == "":
            return "❓ Please enter a gene name.", None, "", ""
        gene_name = gene_name.strip()
        try:
            adata = current_source.adata
            # Get gene expression vector (raises ValueError for unknown genes)
            expression = AnnDataValidator.get_gene_expression(adata, gene_name)
            # Get spatial coordinates
            spatial_coords = adata.obsm["spatial"]
            # Extract background image from h5ad if requested
            background_image = None
            scalefactors = None
            bg_status = ""
            if show_background:
                result = SpatialImageExtractor.get_spatial_image(adata, prefer_lowres=True)
                if result is not None:
                    background_image, scalefactors, image_key = result
                    # Pass image_key to scalefactors so plot knows which scale to use
                    scalefactors = dict(scalefactors)  # Copy to avoid mutating the h5ad dict
                    scalefactors['_image_key'] = image_key
                    bg_status = f" (with {image_key} tissue background)"
                else:
                    bg_status = " (no background image in h5ad)"
            # Create plot
            fig = SpatialPlotter.plot_spatial_gene(
                spatial_coords=spatial_coords,
                expression=expression,
                gene_name=gene_name,
                point_size=point_size,
                use_log=use_log,
                colorscale=colorscale,
                background_image=background_image,
                scalefactors=scalefactors,
                background_opacity=background_opacity,
            )
            # Summary statistics for the sidebar panel
            stats = SpatialPlotter.get_expression_stats(expression)
            stats_text = (
                f"Expression Statistics for {gene_name}:\n"
                f"- Min: {stats['min']:.4f}\n"
                f"- Max: {stats['max']:.4f}\n"
                f"- Mean: {stats['mean']:.4f}\n"
                f"- Median: {stats['median']:.4f}\n"
                f"- Std Dev: {stats['std']:.4f}\n"
                f"- Non-zero: {stats['non_zero_count']:,} ({stats['non_zero_percent']:.1f}%)"
            )
            # Current dataset info
            dataset_info = f"Current dataset: {current_source.name}\n({current_source.n_obs:,} cells, {current_source.n_vars:,} genes)"
            return f"Successfully visualized gene: {gene_name}{bg_status}", fig, stats_text, dataset_info
        except ValueError as e:
            # Expected error: gene not found in the dataset
            return str(e), None, "", ""
        except Exception as e:
            return f"Error visualizing gene: {str(e)}", None, "", ""
def check_spatial_image_available(self) -> bool:
"""Check if current dataset has spatial background image"""
current_source = self.data_manager.get_current_source()
if current_source is None or current_source.adata is None:
return False
return SpatialImageExtractor.has_spatial_image(current_source.adata)
def get_gene_suggestions(self, limit: int = 100) -> list:
"""Get list of available genes for autocomplete"""
current_source = self.data_manager.get_current_source()
if current_source is None or current_source.adata is None:
return []
return AnnDataValidator.get_gene_list(current_source.adata, limit=limit)
def get_current_dataset_info(self) -> str:
"""Get formatted info string for current dataset"""
current_source = self.data_manager.get_current_source()
if current_source is None:
return "No dataset loaded. Please load data first."
if current_source.adata is None:
return f"πŸ“Š Current: {current_source.name}\n(Not yet loaded)"
return f"πŸ“Š Current: {current_source.name}\n({current_source.n_obs:,} cells, {current_source.n_vars:,} genes)"
def get_all_genes(self) -> List[str]:
"""Get full list of genes for autocomplete dropdown"""
current_source = self.data_manager.get_current_source()
if current_source is None or current_source.adata is None:
return []
return list(current_source.adata.var_names)
def search_genes(self, query: str, limit: int = 50) -> List[str]:
"""
Search genes by prefix or substring match
"""
current_source = self.data_manager.get_current_source()
if current_source is None or current_source.adata is None:
return []
if not query or query.strip() == "":
# Return first N genes if no query
return list(current_source.adata.var_names[:limit])
query = query.strip().upper()
all_genes = list(current_source.adata.var_names)
# First: exact prefix matches (prioritized)
prefix_matches = [g for g in all_genes if g.upper().startswith(query)]
# Second: substring matches (lower priority)
substring_matches = [g for g in all_genes if query in g.upper() and g not in prefix_matches]
# Combine and limit
results = prefix_matches + substring_matches
return results[:limit]
    def get_adata_summary(self) -> str:
        """
        Get detailed summary of current AnnData object.

        Returns:
            Markdown-formatted string with h5ad file details (dimensions,
            first variables, obsm/obsp/uns keys, spatial-image availability).
        """
        current_source = self.data_manager.get_current_source()
        if current_source is None:
            return "No dataset loaded"
        if current_source.adata is None:
            # Registered but data was freed / not yet lazily loaded
            return f"πŸ“Š **{current_source.name}**\n\n*Dataset registered but not yet loaded. Select it in the list to load.*"
        adata = current_source.adata
        lines = []
        lines.append(f"πŸ“Š **{current_source.name}**")
        lines.append("")
        # Basic dimensions
        lines.append("### πŸ“ˆ Dimensions")
        lines.append(f"- Observations (cells/spots): **{adata.n_obs:,}**")
        lines.append(f"- Variables (features): **{adata.n_vars:,}**")
        # Spatial coordinates (shape only, if present)
        if "spatial" in adata.obsm:
            spatial_shape = adata.obsm["spatial"].shape
            lines.append(f"- Spatial coordinates: **{spatial_shape}**")
        lines.append("")
        # Variables preview (first 5 names only)
        lines.append("### 🧬 Variables (first 5)")
        var_names = list(adata.var_names[:5])
        lines.append(f"`{', '.join(var_names)}`")
        if adata.n_vars > 5:
            lines.append(f"... and {adata.n_vars - 5:,} more")
        lines.append("")
        # obsm keys (embeddings), capped at 5
        if len(adata.obsm.keys()) > 0:
            lines.append("### πŸ“ obsm (embeddings)")
            for key in list(adata.obsm.keys())[:5]:
                shape = adata.obsm[key].shape
                lines.append(f"- `{key}`: {shape}")
        # obsp keys (pairwise matrices), capped at 3; hasattr guards older AnnData
        if hasattr(adata, 'obsp') and len(adata.obsp.keys()) > 0:
            lines.append("")
            lines.append("### πŸ”— obsp (pairwise)")
            for key in list(adata.obsp.keys())[:3]:
                lines.append(f"- `{key}`")
        # uns keys (unstructured metadata), capped at 6
        if len(adata.uns.keys()) > 0:
            lines.append("")
            lines.append("### πŸ“¦ uns (unstructured)")
            uns_keys = list(adata.uns.keys())[:6]
            lines.append(f"`{', '.join(uns_keys)}`")
            if len(adata.uns.keys()) > 6:
                lines.append(f"... and {len(adata.uns.keys()) - 6} more")
        # Spatial background image availability
        lines.append("")
        lines.append("### πŸ–ΌοΈ Spatial Image")
        if SpatialImageExtractor.has_spatial_image(adata):
            libs = SpatialImageExtractor.get_available_libraries(adata)
            lines.append(f"βœ… Available (libraries: {', '.join(libs)})")
        else:
            lines.append("❌ Not available")
        return "\n".join(lines)
def get_local_h5ad_files(self) -> List[str]:
"""Get list of h5ad files in the data folder"""
data_dir = Path("data")
if not data_dir.exists():
return []
return [f.name for f in data_dir.glob("*.h5ad")]
def create_overview_with_background(self) -> Optional[go.Figure]:
"""Create spatial overview plot with tissue background if available"""
current_source = self.data_manager.get_current_source()
if current_source is None or current_source.adata is None:
return None
adata = current_source.adata
spatial_coords = adata.obsm["spatial"]
# Try to get background image
background_image = None
scalefactors = None
result = SpatialImageExtractor.get_spatial_image(adata, prefer_lowres=True)
if result is not None:
background_image, scalefactors, image_key = result
scalefactors = dict(scalefactors)
scalefactors['_image_key'] = image_key
# Create overview plot with background
return SpatialPlotter.create_overview_plot_with_background(
spatial_coords=spatial_coords,
background_image=background_image,
scalefactors=scalefactors,
)
def parse_variables_list(self, input_text: str) -> Tuple[List[str], List[str], List[str]]:
"""
Parse comma/space/newline separated variables list
Args:
input_text: Raw input text with variable names
Returns:
Tuple of (found_features, not_found_features, all_parsed)
"""
current_source = self.data_manager.get_current_source()
if current_source is None:
return [], [], []
if not input_text or input_text.strip() == "":
return [], [], []
# Parse: split by comma, space, newline, tab
import re
raw_items = re.split(r'[,\s\n\t]+', input_text.strip())
all_parsed = [item.strip() for item in raw_items if item.strip()]
# Check which features exist in dataset
available_genes = set(current_source.adata.var_names)
found_features = [g for g in all_parsed if g in available_genes]
not_found_features = [g for g in all_parsed if g not in available_genes]
return found_features, not_found_features, all_parsed
    def batch_visualize(
        self,
        variables_text: str,
        point_size: int = 5,
        use_log: bool = True,
        colorscale: str = "Viridis",
        show_background: bool = False,
        background_opacity: float = 0.5,
        progress=gr.Progress(track_tqdm=True),
    ) -> Tuple[str, Optional[str], str, str]:
        """
        Perform batch visualization for multiple features.

        For each valid feature: renders a spatial plot, saves it as PNG, and
        records expression statistics. All outputs (PNGs, markdown report,
        stats CSV) are bundled into a single ZIP for download.

        Args:
            variables_text: Comma/space/newline separated feature names
            point_size, use_log, colorscale, show_background, background_opacity: Plot settings
            progress: Gradio progress tracker

        Returns:
            Tuple of (status, zip_file_path, summary_report, stats_csv)
        """
        current_source = self.data_manager.get_current_source()
        if current_source is None:
            return "❌ No dataset loaded. Please load data first.", None, "", ""
        found_features, not_found_features, all_parsed = self.parse_variables_list(variables_text)
        if not found_features:
            return f"❌ No valid features found in dataset.\nParsed: {', '.join(all_parsed)}", None, "", ""
        # Prepare output
        adata = current_source.adata
        spatial_coords = adata.obsm["spatial"]
        # Extract background image once and reuse it for every plot
        background_image = None
        scalefactors = None
        if show_background:
            result = SpatialImageExtractor.get_spatial_image(adata, prefer_lowres=True)
            if result is not None:
                background_image, scalefactors, image_key = result
                scalefactors = dict(scalefactors)
                scalefactors['_image_key'] = image_key
        # Create temp directory for outputs
        # NOTE(review): directory is never removed — files must outlive this
        # call so Gradio can serve the ZIP; presumably relies on OS tmp cleanup.
        temp_dir = tempfile.mkdtemp(prefix="batch_viz_")
        # Track results
        stats_records = []
        successful_plots = []
        failed_features = []
        # Generate one plot per feature; failures are recorded, not fatal
        total = len(found_features)
        for idx, gene_name in enumerate(found_features):
            progress((idx + 1) / total, desc=f"Processing {gene_name} ({idx + 1}/{total})")
            try:
                # Get expression
                expression = AnnDataValidator.get_gene_expression(adata, gene_name)
                # Create plot
                fig = SpatialPlotter.plot_spatial_gene(
                    spatial_coords=spatial_coords,
                    expression=expression,
                    gene_name=gene_name,
                    point_size=point_size,
                    use_log=use_log,
                    colorscale=colorscale,
                    background_image=background_image,
                    scalefactors=scalefactors,
                    background_opacity=background_opacity,
                )
                # Save as PNG (scale=2 for higher resolution export)
                png_path = os.path.join(temp_dir, f"{gene_name}.png")
                fig.write_image(png_path, scale=2)
                successful_plots.append((gene_name, png_path))
                # Record statistics alongside the feature name for the CSV
                stats = SpatialPlotter.get_expression_stats(expression)
                stats['feature'] = gene_name
                stats_records.append(stats)
            except Exception as e:
                failed_features.append((gene_name, str(e)))
        # Generate summary report (markdown)
        report_lines = [
            "# Batch Visualization Report",
            f"Dataset: {current_source.name}",
            f"Total cells/spots: {current_source.n_obs:,}",
            f"Total features: {current_source.n_vars:,}",
            "",
            "## Settings",
            f"- Point Size: {point_size}",
            f"- Log Transform: {use_log}",
            f"- Color Scale: {colorscale}",
            f"- Background: {show_background}",
            "",
            "## Results Summary",
            f"- Total requested: {len(all_parsed)}",
            f"- Found in dataset: {len(found_features)}",
            f"- Successfully visualized: {len(successful_plots)}",
            f"- Failed: {len(failed_features)}",
            "",
        ]
        if not_found_features:
            report_lines.append("## Not Found Features")
            for feat in not_found_features:
                report_lines.append(f"- {feat}")
            report_lines.append("")
        if failed_features:
            report_lines.append("## Failed Features")
            for feat, err in failed_features:
                report_lines.append(f"- {feat}: {err}")
            report_lines.append("")
        report_lines.append("## Successfully Visualized Features")
        for feat, _ in successful_plots:
            report_lines.append(f"- {feat}")
        report_text = "\n".join(report_lines)
        # Save report
        report_path = os.path.join(temp_dir, "report.md")
        with open(report_path, "w") as f:
            f.write(report_text)
        # Save statistics CSV (skipped entirely when no feature succeeded)
        stats_csv_path = os.path.join(temp_dir, "expression_statistics.csv")
        if stats_records:
            with open(stats_csv_path, "w", newline="") as f:
                fieldnames = ['feature', 'min', 'max', 'mean', 'median', 'std', 'non_zero_count', 'non_zero_percent']
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(stats_records)
        # Bundle everything into one downloadable ZIP
        zip_path = os.path.join(temp_dir, "batch_visualization.zip")
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Add images
            for gene_name, png_path in successful_plots:
                zf.write(png_path, f"images/{gene_name}.png")
            # Add report
            zf.write(report_path, "report.md")
            # Add stats CSV
            if stats_records:
                zf.write(stats_csv_path, "expression_statistics.csv")
        # Format stats as a markdown table for in-app display
        stats_display = "Feature | Min | Max | Mean | Non-zero %\n"
        stats_display += "--- | --- | --- | --- | ---\n"
        for rec in stats_records:
            stats_display += f"{rec['feature']} | {rec['min']:.4f} | {rec['max']:.4f} | {rec['mean']:.4f} | {rec['non_zero_percent']:.1f}%\n"
        status = f"βœ… Batch visualization complete!\n- Generated: {len(successful_plots)} plots\n- Failed: {len(failed_features)}"
        return status, zip_path, report_text, stats_display
def create_interface():
"""Create Gradio interface"""
viewer = SpatialViewer()
# Custom CSS
custom_css = """
.duplicate-notice {
background: linear-gradient(135deg, #fff8e1 0%, #ffecb3 100%);
color: #3e2723;
border: 1px solid #ffc107;
border-radius: 8px;
padding: 12px 16px;
margin: 12px 0;
font-size: 0.95rem;
line-height: 1.5;
}
.duplicate-notice b { color: #e65100; }
@media (prefers-color-scheme: dark) {
.duplicate-notice {
background: linear-gradient(135deg, rgba(50,40,20,0.9) 0%, rgba(40,30,10,0.9) 100%);
color: #ffffff;
border-color: #ffc107;
}
.duplicate-notice b { color: #ffd54f; }
}
.file-browser {
background: linear-gradient(180deg, #f8f9fa 0%, #e9ecef 100%);
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 12px;
}
@media (prefers-color-scheme: dark) {
.file-browser {
background: linear-gradient(180deg, #2d2d2d 0%, #1a1a1a 100%);
border-color: #444;
}
}
.data-info-panel {
background: linear-gradient(180deg, #e3f2fd 0%, #bbdefb 100%);
border: 1px solid #90caf9;
border-radius: 8px;
padding: 12px;
}
@media (prefers-color-scheme: dark) {
.data-info-panel {
background: linear-gradient(180deg, rgba(33,150,243,0.15) 0%, rgba(33,150,243,0.05) 100%);
border-color: #1976d2;
}
}
.control-panel {
background: linear-gradient(180deg, #f5f5f5 0%, #eeeeee 100%);
border: 1px solid #e0e0e0;
border-radius: 8px;
padding: 16px;
}
@media (prefers-color-scheme: dark) {
.control-panel {
background: linear-gradient(180deg, #2a2a2a 0%, #1f1f1f 100%);
border-color: #444;
}
}
"""
with gr.Blocks(
title="Spatial Omics Viewer",
theme=gr.themes.Soft(),
css=custom_css,
) as app:
gr.Markdown(
"""
# πŸ”¬ Spatial Omics Viewer
Visualize spatial expression from .h5ad files (AnnData format)
<div class="duplicate-notice">
<b>Notice:</b> This is a public demo Space. For large h5ad files or heavy usage,
please <b>Duplicate this Space</b> to your account for better performance and privacy.
</div>
"""
)
# ==================== Select Dataset Tab ====================
with gr.Tab("πŸ“‚ Select Dataset"):
with gr.Row():
# Column 1: Dataset Browser
with gr.Column(scale=1, elem_classes="file-browser"):
gr.Markdown("### πŸ“ Available Datasets")
gr.Markdown("*Click to select and view*")
# All available datasets (loaded ones)
dataset_selector = gr.Radio(
choices=[],
label="πŸ“¦ Datasets",
value=None,
info="Click to select",
)
gr.Markdown("---")
gr.Markdown("#### πŸ“₯ Import New Data")
import_type = gr.Radio(
choices=["URL", "Upload"],
value="URL",
label="Import Method",
info="Download from URL or upload file",
)
with gr.Group() as url_group:
url_input = gr.Textbox(
label="πŸ”— URL",
placeholder="https://... or Google Drive link",
info="HuggingFace, Zenodo, S3, Google Drive",
lines=1,
)
import_url_btn = gr.Button("πŸ“₯ Import from URL", variant="secondary")
with gr.Group(visible=False) as upload_group:
file_input = gr.File(
label="πŸ“€ Upload File",
file_types=[".h5ad", ".zip"],
type="filepath",
)
load_status = gr.Textbox(
label="Status",
lines=2,
interactive=False,
)
# Column 2: Spatial Overview with background
with gr.Column(scale=2):
gr.Markdown("### πŸ—ΊοΈ Spatial Overview")
overview_plot = gr.Plot(label="Spatial Overview")
# Column 3: Dataset Info
with gr.Column(scale=1, elem_classes="data-info-panel"):
gr.Markdown("### πŸ“Š Dataset Information")
dataset_summary = gr.Markdown(
value="*Select a dataset to see information*",
elem_id="dataset-summary",
)
# ==================== Visualize Tab ====================
with gr.Tab("🎨 Visualize") as visualize_tab:
with gr.Row():
# Column 1: Controls
with gr.Column(scale=1, elem_classes="control-panel"):
gr.Markdown("### βš™οΈ Controls")
gr.Markdown("*Auto-renders when parameters change*", elem_id="auto-render-hint")
# Current dataset
current_dataset_display = gr.Textbox(
label="πŸ“Š Current Dataset",
value="No dataset loaded",
interactive=False,
lines=2,
)
# Gene input
gene_input = gr.Textbox(
label="🧬 Feature Name",
placeholder="Type to search (e.g., Pcp, Gab, Act)",
info="Start typing to see matching features",
)
gene_quick_picks = gr.Radio(
label="πŸ” Quick Pick",
choices=[],
visible=False,
interactive=True,
)
# Plot Settings - default open
with gr.Accordion("πŸŽ›οΈ Plot Settings", open=True):
point_size = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="Point Size",
)
use_log = gr.Checkbox(
value=True,
label="Use log1p transformation",
info="Recommended for better visualization",
)
colorscale = gr.Dropdown(
choices=[
"Viridis", "Plasma", "Inferno", "Magma",
"Cividis", "Blues", "Reds", "YlOrRd", "RdYlBu",
],
value="Viridis",
label="Color Scale",
)
# Tissue Background - default open
with gr.Accordion("πŸ–ΌοΈ Tissue Background", open=True):
show_background = gr.Checkbox(
value=False,
label="Show tissue background",
info="From h5ad file (if available)",
)
background_opacity = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.5,
step=0.1,
label="Background Opacity",
)
# Column 2: Plot
with gr.Column(scale=2):
gr.Markdown("### πŸ”¬ Spatial Omics Expression")
gene_plot = gr.Plot(label="Spatial Omics Expression")
# Column 3: Stats
with gr.Column(scale=1):
gr.Markdown("### πŸ“ˆ Analysis")
vis_status = gr.Textbox(
label="Status",
lines=2,
interactive=False,
)
stats_output = gr.Textbox(
label="Expression Statistics",
lines=10,
interactive=False,
)
# ==================== Batch Visualize Tab ====================
with gr.Tab("πŸ“Š Batch Visualize") as batch_tab:
with gr.Row():
# Column 1: Input & Settings
with gr.Column(scale=1, elem_classes="control-panel"):
gr.Markdown("### πŸ“ Batch Input")
gr.Markdown("*Paste variable names (comma, space, or newline separated)*")
batch_current_dataset = gr.Textbox(
label="πŸ“Š Current Dataset",
value="No dataset loaded",
interactive=False,
lines=2,
)
batch_variables_input = gr.Textbox(
label="🧬 Paste Variables List",
placeholder="Gene1, Gene2, Gene3\nor\nGene1\nGene2\nGene3",
lines=10,
info="Supports comma, space, or newline separated values",
)
batch_parse_btn = gr.Button("πŸ” Parse & Preview", variant="secondary")
batch_parse_result = gr.Markdown(
value="*Enter variables and click Parse to preview*",
elem_id="batch-parse-result",
)
gr.Markdown("---")
gr.Markdown("### βš™οΈ Batch Settings")
with gr.Accordion("πŸŽ›οΈ Plot Settings", open=True):
batch_point_size = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="Point Size",
)
batch_use_log = gr.Checkbox(
value=True,
label="Use log1p transformation",
)
batch_colorscale = gr.Dropdown(
choices=[
"Viridis", "Plasma", "Inferno", "Magma",
"Cividis", "Blues", "Reds", "YlOrRd", "RdYlBu",
],
value="Viridis",
label="Color Scale",
)
with gr.Accordion("πŸ–ΌοΈ Tissue Background", open=True):
batch_show_background = gr.Checkbox(
value=False,
label="Show tissue background",
)
batch_background_opacity = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.5,
step=0.1,
label="Background Opacity",
)
batch_run_btn = gr.Button(
"πŸš€ Run Batch Visualization", variant="primary", size="lg"
)
# Column 2: Preview
with gr.Column(scale=2):
gr.Markdown("### πŸ‘οΈ Preview (First Found Feature)")
batch_preview_plot = gr.Plot(label="Preview")
batch_preview_status = gr.Textbox(
label="Preview Status",
lines=2,
interactive=False,
)
# Column 3: Results
with gr.Column(scale=1):
gr.Markdown("### πŸ“¦ Results")
batch_status = gr.Textbox(
label="Batch Status",
lines=3,
interactive=False,
)
batch_download = gr.File(
label="πŸ“₯ Download Results (ZIP)",
file_count="single",
interactive=False,
)
with gr.Accordion("πŸ“‹ Summary Report", open=True):
batch_report = gr.Markdown(
value="*Run batch visualization to see report*",
)
with gr.Accordion("πŸ“Š Expression Statistics", open=False):
batch_stats = gr.Markdown(
value="*Run batch visualization to see statistics*",
)
# ==================== About Tab ====================
with gr.Tab("ℹ️ About"):
gr.Markdown(
"""
## About This Tool
This tool visualizes spatial omics expression from AnnData (.h5ad) files.
### Features
- πŸš€ Auto-loads demo dataset on startup
- πŸ” Feature name autocomplete search
- πŸ”— Load from URLs (HuggingFace, Zenodo, S3, Google Drive)
- πŸ“€ Upload h5ad/ZIP files
- πŸ–ΌοΈ Tissue background image overlay
- πŸ“Š Interactive Plotly visualization
- πŸ’Ύ Memory-efficient backed mode
### How to Use
1. **Load Data**: Select built-in dataset or import external data
2. **Visualize**: Search for features and visualize spatial expression
3. **Customize**: Adjust plot settings and background
### For Large Files
Please **Duplicate this Space** for large files (>2GB), frequent usage, or private data.
---
Built for the spatial omics research community.
"""
)
# ============================================
# Event bindings
# ============================================
# Import type toggle
        def toggle_import_type(import_method):
            """Show either the URL group or the upload group based on the selected import method."""
            return {
                url_group: gr.update(visible=(import_method == "URL")),
                upload_group: gr.update(visible=(import_method == "Upload")),
            }
import_type.change(
toggle_import_type,
inputs=[import_type],
outputs=[url_group, upload_group],
)
# Switch dataset when clicking on selector
def switch_dataset(source_id):
"""Switch to selected dataset (load if needed) and update all views"""
if not source_id:
return "", None, "*Select a dataset*", viewer.get_current_dataset_info()
try:
# 1. Get source info
source = viewer.data_manager.get_source(source_id)
if source is None:
return f"❌ Dataset {source_id} not found", None, "", ""
# 2. Lazy load if not already loaded
if source.adata is None:
print(f"DEBUG: Lazy loading {source.name} from {source.source_path}")
# Free up memory from other datasets first
import gc
for other_id, other_source in viewer.data_manager.sources.items():
if other_id != source_id and other_source.adata is not None:
print(f"DEBUG: Freeing memory from {other_source.name}")
other_source.adata = None
gc.collect()
# Load current
adata = H5adLoader.load_from_source(source.source_path)
# Validate
is_valid, errors = AnnDataValidator.validate(adata)
if not is_valid:
return f"❌ Validation failed: {'; '.join(errors)}", None, "", ""
# Update source object
source.adata = adata
source.n_obs = adata.n_obs
source.n_vars = adata.n_vars
source.loaded_at = datetime.datetime.now()
# 3. Set as current
viewer.data_manager.set_current(source_id)
# 4. Update all views
overview_fig = viewer.create_overview_with_background()
summary = viewer.get_adata_summary()
dataset_info = viewer.get_current_dataset_info()
choices = viewer.data_manager.get_source_choices()
# Update selector choices to show cell/gene counts
selector_update = gr.update(choices=choices, value=source_id)
return f"βœ… Loaded: {source.name}", overview_fig, summary, dataset_info, selector_update
except Exception as e:
import traceback
print(traceback.format_exc())
return f"❌ Error loading dataset: {str(e)}", None, "", "", gr.update()
dataset_selector.change(
switch_dataset,
inputs=[dataset_selector],
outputs=[load_status, overview_plot, dataset_summary, current_dataset_display, dataset_selector],
)
# Import from URL
        def import_from_url(url):
            """Import dataset(s) from a URL; returns the 5 bound output values.

            Output order: (load_status, overview_plot, dataset_summary,
            dataset_selector update, current_dataset_display).
            """
            if not url or not url.strip():
                return "❌ Please enter a URL", None, "", gr.update(), ""
            url = url.strip()
            # Derive a display name from the last path segment, dropping any query string
            display_name = url.split("/")[-1].split("?")[0] or "URL Dataset"
            try:
                # Clear existing memory-heavy data before loading new one
                import gc
                for source in viewer.data_manager.sources.values():
                    source.adata = None
                gc.collect()
                loaded_data = H5adLoader.load_from_source(url)
                # Normalize to a list: a ZIP source yields several AnnData objects
                if not isinstance(loaded_data, list):
                    loaded_data = [loaded_data]
                last_id = None
                for idx, adata in enumerate(loaded_data):
                    # Any invalid part aborts the whole import
                    is_valid, errors = AnnDataValidator.validate(adata)
                    if not is_valid:
                        return f"❌ Validation failed: {'; '.join(errors)}", None, "", gr.update(), ""
                    name = display_name if len(loaded_data) == 1 else f"{display_name} - Part {idx + 1}"
                    last_id = viewer.data_manager.add_source(
                        name=name,
                        source_type="url",
                        source_path=url,
                        adata=adata
                    )
                # Set the last imported one as current
                if last_id:
                    viewer.data_manager.set_current(last_id)
                # Update views
                overview_fig = viewer.create_overview_with_background()
                summary = viewer.get_adata_summary()
                choices = viewer.data_manager.get_source_choices()
                selector_update = gr.update(choices=choices, value=viewer.data_manager.current_id)
                dataset_info = viewer.get_current_dataset_info()
                return f"βœ… Imported: {display_name}", overview_fig, summary, selector_update, dataset_info
            except Exception as e:
                return f"❌ Error: {str(e)}", None, "", gr.update(), ""
# Import button: fetch from the typed URL, then refresh status, plot,
# summary, selector choices, and the current-dataset banner.
import_url_btn.click(
    import_from_url,
    inputs=[url_input],
    outputs=[load_status, overview_plot, dataset_summary, dataset_selector, current_dataset_display],
)
# Upload file
def upload_file(uploaded_file):
    """Handle an .h5ad file uploaded through the UI.

    Frees cached AnnData from all existing sources first to bound memory,
    then loads, validates, and registers the uploaded dataset(s), making
    the last one current.

    Args:
        uploaded_file: Path of the uploaded file as provided by Gradio.

    Returns:
        Tuple of (status message, overview figure, summary markdown,
        dataset-selector update, current-dataset info), matching the
        Gradio outputs this handler is wired to.
    """
    if not uploaded_file:
        return "❌ No file uploaded", None, "", gr.update(), ""
    display_name = Path(uploaded_file).name
    try:
        # Clear existing memory-heavy data
        import gc
        for source in viewer.data_manager.sources.values():
            source.adata = None
        gc.collect()
        loaded_data = H5adLoader.load_from_source(uploaded_file)
        if not isinstance(loaded_data, list):
            loaded_data = [loaded_data]
        # Validate every part before registering anything so a failure in
        # part N of a multi-part file doesn't leave parts 1..N-1 behind.
        for adata in loaded_data:
            is_valid, errors = AnnDataValidator.validate(adata)
            if not is_valid:
                return f"❌ Validation failed: {'; '.join(errors)}", None, "", gr.update(), ""
        last_id = None
        for idx, adata in enumerate(loaded_data):
            name = display_name if len(loaded_data) == 1 else f"{display_name} - Part {idx + 1}"
            last_id = viewer.data_manager.add_source(
                name=name,
                source_type="upload",
                source_path=uploaded_file,
                adata=adata
            )
        # Set as current
        if last_id:
            viewer.data_manager.set_current(last_id)
        # Update views
        overview_fig = viewer.create_overview_with_background()
        summary = viewer.get_adata_summary()
        choices = viewer.data_manager.get_source_choices()
        selector_update = gr.update(choices=choices, value=viewer.data_manager.current_id)
        dataset_info = viewer.get_current_dataset_info()
        return f"✅ Uploaded: {display_name}", overview_fig, summary, selector_update, dataset_info
    except Exception as e:
        # Log the full traceback server-side (consistent with
        # switch_dataset) so upload failures are debuggable.
        import traceback
        print(traceback.format_exc())
        return f"❌ Error: {str(e)}", None, "", gr.update(), ""
# Fires as soon as a file finishes uploading (no explicit button).
file_input.change(
    upload_file,
    inputs=[file_input],
    outputs=[load_status, overview_plot, dataset_summary, dataset_selector, current_dataset_display],
)
# Visualize tab events
def update_on_tab_select():
    """Return the current-dataset banner text for the Visualize tab."""
    info = viewer.get_current_dataset_info()
    return info
# Refresh the dataset banner each time the Visualize tab is opened, so
# it reflects any dataset switch made on another tab.
visualize_tab.select(
    update_on_tab_select,
    inputs=[],
    outputs=[current_dataset_display],
)
def live_search(query):
    """Populate the quick-pick dropdown with up to 15 gene matches.

    The dropdown stays hidden until the query has at least two
    characters after trimming whitespace, or when nothing matches.
    """
    hidden = gr.update(choices=[], visible=False)
    if not query or len(query.strip()) < 2:
        return hidden
    matches = viewer.search_genes(query, limit=15)
    if not matches:
        return hidden
    return gr.update(choices=matches, visible=True, value=None)
# Live search: every keystroke in the gene box refreshes the quick-pick
# dropdown.
gene_input.change(
    live_search,
    inputs=[gene_input],
    outputs=[gene_quick_picks],
)
def quick_visualize(selected_gene, point_size, use_log, colorscale, show_bg, bg_opacity):
    """Render a gene chosen from the quick-pick list.

    Copies the chosen name into the gene text box, renders the plot, and
    collapses the quick-pick dropdown again.
    """
    hide_picks = gr.update(visible=False)
    if not selected_gene:
        return gr.update(), None, "", "", hide_picks, ""
    rendered = viewer.visualize_gene(
        selected_gene, point_size, use_log, colorscale, show_bg, bg_opacity
    )
    status, plot, stats, dataset_info = rendered
    return selected_gene, plot, stats, dataset_info, hide_picks, status
# Selecting a quick-pick renders it immediately and hides the dropdown.
gene_quick_picks.change(
    quick_visualize,
    inputs=[gene_quick_picks, point_size, use_log, colorscale, show_background, background_opacity],
    outputs=[gene_input, gene_plot, stats_output, current_dataset_display, gene_quick_picks, vis_status],
)
# Auto-render when any parameter changes
def auto_visualize(gene_name, pt_size, log_transform, color_scale, show_bg, bg_opacity):
    """Re-render the spatial plot whenever a display parameter changes.

    When no gene has been entered yet, status/plot/stats are left
    untouched and only the dataset banner is cleared.
    """
    if not (gene_name and gene_name.strip()):
        return gr.update(), gr.update(), gr.update(), ""
    rendered = viewer.visualize_gene(
        gene_name, pt_size, log_transform, color_scale, show_bg, bg_opacity
    )
    status, plot, stats, dataset_info = rendered
    return status, plot, stats, dataset_info
# Bind auto-render to all parameter changes
# Shared input/output lists reused by every auto-render binding below.
auto_render_inputs = [gene_input, point_size, use_log, colorscale, show_background, background_opacity]
auto_render_outputs = [vis_status, gene_plot, stats_output, current_dataset_display]
# Re-render on gene input blur (when user finishes typing)
gene_input.blur(
    auto_visualize,
    inputs=auto_render_inputs,
    outputs=auto_render_outputs,
)
# Re-render on parameter changes
# Sliders (point_size, background_opacity) are bound on .release so the
# render fires once when the handle is dropped; the other controls bind
# on .change.
point_size.release(
    auto_visualize,
    inputs=auto_render_inputs,
    outputs=auto_render_outputs,
)
use_log.change(
    auto_visualize,
    inputs=auto_render_inputs,
    outputs=auto_render_outputs,
)
colorscale.change(
    auto_visualize,
    inputs=auto_render_inputs,
    outputs=auto_render_outputs,
)
show_background.change(
    auto_visualize,
    inputs=auto_render_inputs,
    outputs=auto_render_outputs,
)
background_opacity.release(
    auto_visualize,
    inputs=auto_render_inputs,
    outputs=auto_render_outputs,
)
# ============================================
# Batch Visualize Tab Events
# ============================================
def update_batch_dataset():
    """Return the current-dataset banner text for the Batch tab."""
    current_info = viewer.get_current_dataset_info()
    return current_info
# Keep the Batch tab's dataset banner in sync when the tab is opened.
batch_tab.select(
    update_batch_dataset,
    inputs=[],
    outputs=[batch_current_dataset],
)
def parse_and_preview(variables_text, pt_size, log_transform, color_scale, show_bg, bg_opacity):
    """Parse the pasted feature list and preview the first found feature.

    Builds a markdown report of parsed/found/not-found counts (with short
    name samples), then renders the first matching feature if any.
    """
    found, not_found, all_parsed = viewer.parse_variables_list(variables_text)

    def _sample_line(names, limit):
        # One markdown bullet: first `limit` names, plus "(+k more)" tail.
        bullet = f"- `{', '.join(names[:limit])}`"
        overflow = len(names) - limit
        return bullet + (f" ... (+{overflow} more)" if overflow > 0 else "")

    report = [
        f"**Parsed:** {len(all_parsed)} items",
        f"**Found:** {len(found)} features",
    ]
    if found:
        report.append(_sample_line(found, 10))
    report.append(f"**Not Found:** {len(not_found)} items")
    if not_found:
        report.append(_sample_line(not_found, 5))
    parse_result = "\n".join(report)

    if not found:
        return parse_result, None, "No features found to preview"
    first_gene = found[0]
    status, plot, stats, _ = viewer.visualize_gene(
        first_gene, pt_size, log_transform, color_scale, show_bg, bg_opacity
    )
    return parse_result, plot, f"Previewing: {first_gene}"
# Parse & Preview button: validate the pasted feature list and preview
# the first feature that exists in the current dataset.
batch_parse_btn.click(
    parse_and_preview,
    inputs=[batch_variables_input, batch_point_size, batch_use_log, batch_colorscale, batch_show_background, batch_background_opacity],
    outputs=[batch_parse_result, batch_preview_plot, batch_preview_status],
)
# Auto-update preview when settings change (if there's already input)
def update_preview_on_settings(variables_text, pt_size, log_transform, color_scale, show_bg, bg_opacity):
    """Refresh the batch preview plot when a display setting changes.

    Does nothing (returns no-op updates) until the text box parses to at
    least one known feature.
    """
    found, _, _ = viewer.parse_variables_list(variables_text)
    if not found:
        return gr.update(), gr.update()
    gene = found[0]
    _, plot, _, _ = viewer.visualize_gene(
        gene, pt_size, log_transform, color_scale, show_bg, bg_opacity
    )
    return plot, f"Previewing: {gene}"
# Shared input/output lists for the batch-preview refresh bindings below.
batch_preview_inputs = [batch_variables_input, batch_point_size, batch_use_log, batch_colorscale, batch_show_background, batch_background_opacity]
batch_preview_outputs = [batch_preview_plot, batch_preview_status]
# Sliders rebind on .release, other controls on .change (same pattern as
# the single-gene tab's auto-render bindings).
batch_point_size.release(update_preview_on_settings, inputs=batch_preview_inputs, outputs=batch_preview_outputs)
batch_use_log.change(update_preview_on_settings, inputs=batch_preview_inputs, outputs=batch_preview_outputs)
batch_colorscale.change(update_preview_on_settings, inputs=batch_preview_inputs, outputs=batch_preview_outputs)
batch_show_background.change(update_preview_on_settings, inputs=batch_preview_inputs, outputs=batch_preview_outputs)
batch_background_opacity.release(update_preview_on_settings, inputs=batch_preview_inputs, outputs=batch_preview_outputs)
def run_batch_visualization(variables_text, pt_size, log_transform, color_scale, show_bg, bg_opacity, progress=gr.Progress()):
    """Forward the batch settings to viewer.batch_visualize.

    Returns (status, zip path, markdown report, stats) for the batch-tab
    outputs; `progress` is Gradio's injected progress tracker.
    """
    return viewer.batch_visualize(
        variables_text, pt_size, log_transform, color_scale, show_bg, bg_opacity, progress
    )
# Run Batch button: render every parsed feature and expose a zip
# download plus report/stats.
batch_run_btn.click(
    run_batch_visualization,
    inputs=[batch_variables_input, batch_point_size, batch_use_log, batch_colorscale, batch_show_background, batch_background_opacity],
    outputs=[batch_status, batch_download, batch_report, batch_stats],
)
# Auto-load all demo datasets on startup
def startup_load():
    """Register bundled data/*.h5ad files at startup without loading them.

    Sources are added with adata=None so no RAM is spent until the user
    actually selects a dataset; the first entry is only marked current so
    the selector shows a value.

    Returns:
        (status message, overview figure, summary markdown,
        dataset-selector update, current-dataset info).
    """
    manager = viewer.data_manager
    # Already registered (e.g. a page reload): just refresh the views.
    if manager.has_sources():
        overview_fig = viewer.create_overview_with_background()
        summary = viewer.get_adata_summary()
        choices = manager.get_source_choices()
        dataset_info = viewer.get_current_dataset_info()
        selector_update = gr.update(choices=choices, value=manager.current_id)
        return "✅ Ready", overview_fig, summary, selector_update, dataset_info
    # Register every local .h5ad file as a lazily-loaded source.
    for filename in viewer.get_local_h5ad_files():
        manager.add_source(
            name=filename,
            source_type="demo",
            source_path=str(Path("data") / filename),
            adata=None,  # deferred: loaded on first selection
        )
    if not manager.has_sources():
        return "No datasets found in data/ folder", None, "", gr.update(), ""
    choices = manager.get_source_choices()
    # Mark the first dataset current (choices are (label, value) pairs)
    # so the UI shows a selection, without loading anything into RAM.
    manager.current_id = choices[0][1]
    return (
        "📂 Datasets found. Select one to load and visualize.",
        None,
        "*Select a dataset to load*",
        gr.update(choices=choices, value=manager.current_id),
        "No dataset loaded",
    )
# On app startup, register the bundled datasets (lazily — see
# startup_load) and prime the landing-tab widgets.
app.load(
    startup_load,
    inputs=[],
    outputs=[load_status, overview_plot, dataset_summary, dataset_selector, current_dataset_display],
)
return app
if __name__ == "__main__":
    # Build the Gradio UI and start the local server.
    demo = create_interface()
    demo.launch()